diff --git a/aio_pika/patterns/rpc.py b/aio_pika/patterns/rpc.py index bbec383e..80a05c53 100644 --- a/aio_pika/patterns/rpc.py +++ b/aio_pika/patterns/rpc.py @@ -84,11 +84,17 @@ def __init__(self, channel: Channel): self.queues = {} self.consumer_tags = {} self.dlx_exchange = None + self.rpc_exchange = None def __remove_future(self, future: asyncio.Future): log.debug("Remove done future %r", future) self.futures.pop(id(future), None) + def _routing_key(self, method_name): + return '{}::{}'.format(self.rpc_exchange.name, method_name) \ + if self.rpc_exchange \ + else method_name + def create_future(self) -> asyncio.Future: future = self.loop.create_future() log.debug("Create future for RPC call") @@ -126,10 +132,21 @@ async def close(self): self.result_queue = None @shield - async def initialize(self, auto_delete=True, durable=False, **kwargs): + async def initialize(self, + auto_delete=True, + durable=False, + exchange='', + **kwargs): if self.result_queue is not None: return + if exchange: + self.rpc_exchange = await self.channel.declare_exchange( + exchange, + type=ExchangeType.DIRECT, + auto_delete=True, + durable=durable) + self.result_queue = await self.channel.declare_queue( None, auto_delete=auto_delete, durable=durable, **kwargs ) @@ -222,13 +239,15 @@ async def on_result_message(self, message: IncomingMessage): async def on_call_message( self, method_name: str, message: IncomingMessage ): - if method_name not in self.routes: + routing_key = self._routing_key(method_name) + + if routing_key not in self.routes: log.warning("Method %r not registered in %r", method_name, self) return try: payload = self.deserialize(message.body) - func = self.routes[method_name] + func = self.routes[routing_key] result = await self.execute(func, payload) result = self.serialize(result) @@ -341,12 +360,19 @@ async def call( if expiration is not None: message.expiration = expiration - log.debug("Publishing calls for %s(%r)", method_name, kwargs) - await self.channel.default_exchange.publish( - message, routing_key=method_name, mandatory=True, - ) + routing_key = self._routing_key(method_name) + + log.debug("Publishing calls for %s(%r)", routing_key, kwargs) + if self.rpc_exchange: + await self.rpc_exchange.publish( + message, routing_key=routing_key, mandatory=True, + ) + else: + await self.channel.default_exchange.publish( + message, routing_key=routing_key, mandatory=True, + ) - log.debug("Waiting RPC result for %s(%r)", method_name, kwargs) + log.debug("Waiting RPC result for %s(%r)", routing_key, kwargs) return await future async def register(self, method_name, func: CallbackType, **kwargs): @@ -366,21 +392,29 @@ async def register(self, method_name, func: CallbackType, **kwargs): kwargs["arguments"] = arguments - queue = await self.channel.declare_queue(method_name, **kwargs) + routing_key = self._routing_key(method_name) + + queue = await self.channel.declare_queue(routing_key, **kwargs) + + if self.rpc_exchange: + await queue.bind( + self.rpc_exchange, + routing_key + ) if func in self.consumer_tags: raise RuntimeError("Function already registered") - if method_name in self.routes: + if routing_key in self.routes: raise RuntimeError( - "Method name already used for %r" % self.routes[method_name], + "Method name already used for %r" % self.routes[routing_key], ) self.consumer_tags[func] = await queue.consume( partial(self.on_call_message, method_name), ) - self.routes[method_name] = awaitable(func) + self.routes[routing_key] = awaitable(func) self.queues[func] = queue async def unregister(self, func): diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 8582489e..43accafe 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -19,6 +19,13 @@ def rpc_func(*, foo, bar): return {"foo": "bar"} +def rpc_func2(*, foo, bar): + assert not foo + assert not bar + + return {"foo": "bar2"} + + class TestCase: async def test_simple(self, channel: aio_pika.Channel): rpc = await RPC.create(channel, auto_delete=True) @@ -172,3 +179,35 @@ async def test_register_twice(self, channel: aio_pika.Channel): await rpc.unregister(rpc_func) await rpc.close() + + async def test_custom_exchange(self, channel: aio_pika.Channel): + rpc_ex1 = await RPC.create(channel, auto_delete=True, exchange='ex1') + rpc_ex2 = await RPC.create(channel, auto_delete=True, exchange='ex2') + rpc_default = await RPC.create(channel, auto_delete=True) + + await rpc_ex1.register("test.rpc", rpc_func, auto_delete=True) + result = await rpc_ex1.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar"} + + with pytest.raises(DeliveryError): + await rpc_ex2.proxy.test.rpc(foo=None, bar=None) + + await rpc_ex2.register("test.rpc", rpc_func2, auto_delete=True) + result = await rpc_ex2.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar2"} + + with pytest.raises(DeliveryError): + await rpc_default.proxy.test.rpc(foo=None, bar=None) + + await rpc_default.register("test.rpc", rpc_func, auto_delete=True) + result = await rpc_default.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar"} + + await rpc_ex1.unregister(rpc_func) + await rpc_ex1.close() + + await rpc_ex2.unregister(rpc_func2) + await rpc_ex2.close() + + await rpc_default.unregister(rpc_func) + await rpc_default.close()