From 2adce0caa8d2436a489910186cc1edfe34626fee Mon Sep 17 00:00:00 2001 From: Meir Tseitlin Date: Fri, 12 Feb 2021 12:22:43 -0600 Subject: [PATCH] new: add custom exchanges to rpc pattern --- aio_pika/patterns/rpc.py | 55 ++++++++++++++++++++++++++++++---------- tests/test_rpc.py | 39 ++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 13 deletions(-) diff --git a/aio_pika/patterns/rpc.py b/aio_pika/patterns/rpc.py index 000aaf42..95b13267 100644 --- a/aio_pika/patterns/rpc.py +++ b/aio_pika/patterns/rpc.py @@ -85,6 +85,7 @@ def multiply(*, x, y): result_queue: AbstractQueue result_consumer_tag: ConsumerTag dlx_exchange: AbstractExchange + rpc_exchange: Optional[AbstractExchange] def __init__( self, channel: AbstractChannel, @@ -115,6 +116,13 @@ def create_future(self) -> Tuple[asyncio.Future, str]: future.add_done_callback(self.__remove_future(correlation_id)) return future, correlation_id + def _format_routing_key(self, method_name: str) -> str: + return ( + f'{self.rpc_exchange.name}::{method_name}' + if self.rpc_exchange + else method_name + ) + async def close(self) -> None: if not hasattr(self, "result_queue"): log.warning("RPC already closed") @@ -142,13 +150,22 @@ async def close(self) -> None: del self.result_queue del self.dlx_exchange + if self.rpc_exchange: + del self.rpc_exchange + async def initialize( self, auto_delete: bool = True, - durable: bool = False, **kwargs: Any, + durable: bool = False, exchange: str = '', **kwargs: Any, ) -> None: if hasattr(self, "result_queue"): return + self.rpc_exchange = await self.channel.declare_exchange( + exchange, + type=ExchangeType.DIRECT, + auto_delete=True, + durable=durable) if exchange else None + self.result_queue = await self.channel.declare_queue( None, auto_delete=auto_delete, durable=durable, **kwargs, ) @@ -252,14 +269,16 @@ async def on_result_message(self, message: AbstractIncomingMessage) -> None: async def on_call_message( self, method_name: str, message: IncomingMessage, ) -> None: - if method_name not in self.routes: + + routing_key = self._format_routing_key(method_name) + + if routing_key not in self.routes: log.warning("Method %r not registered in %r", method_name, self) return try: payload = await self.deserialize_message(message) - func = self.routes[method_name] - + func = self.routes[routing_key] result: Any = await self.execute(func, payload) message_type = RPCMessageType.RESULT except Exception as e: @@ -376,12 +395,14 @@ async def call( if expiration is not None: message.expiration = expiration - log.debug("Publishing calls for %s(%r)", method_name, kwargs) - await self.channel.default_exchange.publish( - message, routing_key=method_name, mandatory=True, - ) + routing_key = self._format_routing_key(method_name) + + log.debug("Publishing calls for %s(%r)", routing_key, kwargs) + exchange = self.rpc_exchange or self.channel.default_exchange + await exchange.publish(message, routing_key=routing_key, + mandatory=True) - log.debug("Waiting RPC result for %s(%r)", method_name, kwargs) + log.debug("Waiting RPC result for %s(%r)", routing_key, kwargs) return await future async def register( @@ -405,21 +426,29 @@ async def register( kwargs["arguments"] = arguments - queue = await self.channel.declare_queue(method_name, **kwargs) + routing_key = self._format_routing_key(method_name) + + queue = await self.channel.declare_queue(routing_key, **kwargs) + + if self.rpc_exchange: + await queue.bind( + self.rpc_exchange, + routing_key + ) if func in self.consumer_tags: raise RuntimeError("Function already registered") - if method_name in self.routes: + if routing_key in self.routes: raise RuntimeError( - "Method name already used for %r" % self.routes[method_name], + "Method name already used for %r" % self.routes[routing_key], ) self.consumer_tags[func] = await queue.consume( partial(self.on_call_message, method_name), ) - self.routes[method_name] = func + self.routes[routing_key] = func self.queues[func] = queue async def unregister(self, func: CallbackType) -> None: diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 7b7ed156..c931a6cf 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -21,6 +21,13 @@ async def rpc_func(*, foo, bar): return {"foo": "bar"} +async def rpc_func2(*, foo, bar): + assert not foo + assert not bar + + return {"foo": "bar2"} + + class TestCase: async def test_simple(self, channel: aio_pika.Channel): rpc = await RPC.create(channel, auto_delete=True) @@ -238,3 +245,35 @@ async def inner(): with pytest.raises(TypeError): await rpc.call("test.not-serializable") + + async def test_custom_exchange(self, channel: aio_pika.Channel): + rpc_ex1 = await RPC.create(channel, auto_delete=True, exchange='ex1') + rpc_ex2 = await RPC.create(channel, auto_delete=True, exchange='ex2') + rpc_default = await RPC.create(channel, auto_delete=True) + + await rpc_ex1.register("test.rpc", rpc_func, auto_delete=True) + result = await rpc_ex1.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar"} + + with pytest.raises(MessageProcessError): + await rpc_ex2.proxy.test.rpc(foo=None, bar=None) + + await rpc_ex2.register("test.rpc", rpc_func2, auto_delete=True) + result = await rpc_ex2.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar2"} + + with pytest.raises(MessageProcessError): + await rpc_default.proxy.test.rpc(foo=None, bar=None) + + await rpc_default.register("test.rpc", rpc_func, auto_delete=True) + result = await rpc_default.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar"} + + await rpc_ex1.unregister(rpc_func) + await rpc_ex1.close() + + await rpc_ex2.unregister(rpc_func2) + await rpc_ex2.close() + + await rpc_default.unregister(rpc_func) + await rpc_default.close()