diff --git a/aio_pika/patterns/rpc.py b/aio_pika/patterns/rpc.py index 000aaf42..b6bd188f 100644 --- a/aio_pika/patterns/rpc.py +++ b/aio_pika/patterns/rpc.py @@ -115,6 +115,11 @@ def create_future(self) -> Tuple[asyncio.Future, str]: future.add_done_callback(self.__remove_future(correlation_id)) return future, correlation_id + def _routing_key(self, method_name): + return f'{self.rpc_exchange.name}::{method_name}' \ + if self.rpc_exchange \ + else method_name + async def close(self) -> None: if not hasattr(self, "result_queue"): log.warning("RPC already closed") @@ -142,13 +147,22 @@ async def close(self) -> None: del self.result_queue del self.dlx_exchange + if self.rpc_exchange: + del self.rpc_exchange + async def initialize( self, auto_delete: bool = True, - durable: bool = False, **kwargs: Any, + durable: bool = False, exchange: str ='', **kwargs: Any, ) -> None: if hasattr(self, "result_queue"): return + self.rpc_exchange = await self.channel.declare_exchange( + exchange, + type=ExchangeType.DIRECT, + auto_delete=True, + durable=durable) if exchange else None + self.result_queue = await self.channel.declare_queue( None, auto_delete=auto_delete, durable=durable, **kwargs, ) @@ -252,14 +266,16 @@ async def on_result_message(self, message: AbstractIncomingMessage) -> None: async def on_call_message( self, method_name: str, message: IncomingMessage, ) -> None: - if method_name not in self.routes: + + routing_key = self._routing_key(method_name) + + if routing_key not in self.routes: log.warning("Method %r not registered in %r", method_name, self) return try: payload = await self.deserialize_message(message) - func = self.routes[method_name] - + func = self.routes[routing_key] result: Any = await self.execute(func, payload) message_type = RPCMessageType.RESULT except Exception as e: @@ -376,12 +392,19 @@ async def call( if expiration is not None: message.expiration = expiration - log.debug("Publishing calls for %s(%r)", method_name, kwargs) - await self.channel.default_exchange.publish( - message, routing_key=method_name, mandatory=True, - ) + routing_key = self._routing_key(method_name) - log.debug("Waiting RPC result for %s(%r)", method_name, kwargs) + log.debug("Publishing calls for %s(%r)", routing_key, kwargs) + if self.rpc_exchange: + await self.rpc_exchange.publish( + message, routing_key=routing_key, mandatory=True, + ) + else: + await self.channel.default_exchange.publish( + message, routing_key=routing_key, mandatory=True, + ) + + log.debug("Waiting RPC result for %s(%r)", routing_key, kwargs) return await future async def register( @@ -405,21 +428,29 @@ async def register( kwargs["arguments"] = arguments - queue = await self.channel.declare_queue(method_name, **kwargs) + routing_key = self._routing_key(method_name) + + queue = await self.channel.declare_queue(routing_key, **kwargs) + + if self.rpc_exchange: + await queue.bind( + self.rpc_exchange, + routing_key + ) if func in self.consumer_tags: raise RuntimeError("Function already registered") - if method_name in self.routes: + if routing_key in self.routes: raise RuntimeError( - "Method name already used for %r" % self.routes[method_name], + "Method name already used for %r" % self.routes[routing_key], ) self.consumer_tags[func] = await queue.consume( partial(self.on_call_message, method_name), ) - self.routes[method_name] = func + self.routes[routing_key] = func self.queues[func] = queue async def unregister(self, func: CallbackType) -> None: diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 7b7ed156..c931a6cf 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -21,6 +21,13 @@ async def rpc_func(*, foo, bar): return {"foo": "bar"} +async def rpc_func2(*, foo, bar): + assert not foo + assert not bar + + return {"foo": "bar2"} + + class TestCase: async def test_simple(self, channel: aio_pika.Channel): rpc = await RPC.create(channel, auto_delete=True) @@ -238,3 +245,35 @@ async def inner(): with pytest.raises(TypeError): await rpc.call("test.not-serializable") + + async def test_custom_exchange(self, channel: aio_pika.Channel): + rpc_ex1 = await RPC.create(channel, auto_delete=True, exchange='ex1') + rpc_ex2 = await RPC.create(channel, auto_delete=True, exchange='ex2') + rpc_default = await RPC.create(channel, auto_delete=True) + + await rpc_ex1.register("test.rpc", rpc_func, auto_delete=True) + result = await rpc_ex1.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar"} + + with pytest.raises(MessageProcessError): + await rpc_ex2.proxy.test.rpc(foo=None, bar=None) + + await rpc_ex2.register("test.rpc", rpc_func2, auto_delete=True) + result = await rpc_ex2.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar2"} + + with pytest.raises(MessageProcessError): + await rpc_default.proxy.test.rpc(foo=None, bar=None) + + await rpc_default.register("test.rpc", rpc_func, auto_delete=True) + result = await rpc_default.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar"} + + await rpc_ex1.unregister(rpc_func) + await rpc_ex1.close() + + await rpc_ex2.unregister(rpc_func2) + await rpc_ex2.close() + + await rpc_default.unregister(rpc_func) + await rpc_default.close()