Skip to content

Commit

Permalink
new: add custom exchanges to rpc pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-rocket committed Sep 16, 2021
1 parent d5993fb commit c237f93
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 12 deletions.
58 changes: 46 additions & 12 deletions aio_pika/patterns/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,17 @@ def __init__(self, channel: Channel):
self.queues = {}
self.consumer_tags = {}
self.dlx_exchange = None
self.rpc_exchange = None

def __remove_future(self, future: asyncio.Future):
log.debug("Remove done future %r", future)
self.futures.pop(id(future), None)

def _routing_key(self, method_name):
return '{}::{}'.format(self.rpc_exchange.name, method_name) \
if self.rpc_exchange \
else method_name

def create_future(self) -> asyncio.Future:
future = self.loop.create_future()
log.debug("Create future for RPC call")
Expand Down Expand Up @@ -126,10 +132,21 @@ async def close(self):
self.result_queue = None

@shield
async def initialize(self, auto_delete=True, durable=False, **kwargs):
async def initialize(self,
auto_delete=True,
durable=False,
exchange='',
**kwargs):
if self.result_queue is not None:
return

if exchange:
self.rpc_exchange = await self.channel.declare_exchange(
exchange,
type=ExchangeType.DIRECT,
auto_delete=True,
durable=durable)

self.result_queue = await self.channel.declare_queue(
None, auto_delete=auto_delete, durable=durable, **kwargs
)
Expand Down Expand Up @@ -222,13 +239,15 @@ async def on_result_message(self, message: IncomingMessage):
async def on_call_message(
self, method_name: str, message: IncomingMessage
):
if method_name not in self.routes:
routing_key = self._routing_key(method_name)

if routing_key not in self.routes:
log.warning("Method %r not registered in %r", method_name, self)
return

try:
payload = self.deserialize(message.body)
func = self.routes[method_name]
func = self.routes[routing_key]

result = await self.execute(func, payload)
result = self.serialize(result)
Expand Down Expand Up @@ -341,12 +360,19 @@ async def call(
if expiration is not None:
message.expiration = expiration

log.debug("Publishing calls for %s(%r)", method_name, kwargs)
await self.channel.default_exchange.publish(
message, routing_key=method_name, mandatory=True,
)
routing_key = self._routing_key(method_name)

log.debug("Publishing calls for %s(%r)", routing_key, kwargs)
if self.rpc_exchange:
await self.rpc_exchange.publish(
message, routing_key=routing_key, mandatory=True,
)
else:
await self.channel.default_exchange.publish(
message, routing_key=routing_key, mandatory=True,
)

log.debug("Waiting RPC result for %s(%r)", method_name, kwargs)
log.debug("Waiting RPC result for %s(%r)", routing_key, kwargs)
return await future

async def register(self, method_name, func: CallbackType, **kwargs):
Expand All @@ -366,21 +392,29 @@ async def register(self, method_name, func: CallbackType, **kwargs):

kwargs["arguments"] = arguments

queue = await self.channel.declare_queue(method_name, **kwargs)
routing_key = self._routing_key(method_name)

queue = await self.channel.declare_queue(routing_key, **kwargs)

if self.rpc_exchange:
await queue.bind(
self.rpc_exchange,
routing_key
)

if func in self.consumer_tags:
raise RuntimeError("Function already registered")

if method_name in self.routes:
if routing_key in self.routes:
raise RuntimeError(
"Method name already used for %r" % self.routes[method_name],
"Method name already used for %r" % self.routes[routing_key],
)

self.consumer_tags[func] = await queue.consume(
partial(self.on_call_message, method_name),
)

self.routes[method_name] = awaitable(func)
self.routes[routing_key] = awaitable(func)
self.queues[func] = queue

async def unregister(self, func):
Expand Down
39 changes: 39 additions & 0 deletions tests/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def rpc_func(*, foo, bar):
return {"foo": "bar"}


def rpc_func2(*, foo, bar):
assert not foo
assert not bar

return {"foo": "bar2"}


class TestCase:
async def test_simple(self, channel: aio_pika.Channel):
rpc = await RPC.create(channel, auto_delete=True)
Expand Down Expand Up @@ -172,3 +179,35 @@ async def test_register_twice(self, channel: aio_pika.Channel):
await rpc.unregister(rpc_func)

await rpc.close()

async def test_custom_exchange(self, channel: aio_pika.Channel):
rpc_ex1 = await RPC.create(channel, auto_delete=True, exchange='ex1')
rpc_ex2 = await RPC.create(channel, auto_delete=True, exchange='ex2')
rpc_default = await RPC.create(channel, auto_delete=True)

await rpc_ex1.register("test.rpc", rpc_func, auto_delete=True)
result = await rpc_ex1.proxy.test.rpc(foo=None, bar=None)
assert result == {"foo": "bar"}

with pytest.raises(DeliveryError):
await rpc_ex2.proxy.test.rpc(foo=None, bar=None)

await rpc_ex2.register("test.rpc", rpc_func2, auto_delete=True)
result = await rpc_ex2.proxy.test.rpc(foo=None, bar=None)
assert result == {"foo": "bar2"}

with pytest.raises(DeliveryError):
await rpc_default.proxy.test.rpc(foo=None, bar=None)

await rpc_default.register("test.rpc", rpc_func, auto_delete=True)
result = await rpc_default.proxy.test.rpc(foo=None, bar=None)
assert result == {"foo": "bar"}

await rpc_ex1.unregister(rpc_func)
await rpc_ex1.close()

await rpc_ex2.unregister(rpc_func2)
await rpc_ex2.close()

await rpc_default.unregister(rpc_func)
await rpc_default.close()

0 comments on commit c237f93

Please sign in to comment.