Skip to content

Commit

Permalink
Merge pull request #377 from cloud-rocket/add-custom-rpc-exchange
Browse files Browse the repository at this point in the history
new: add custom exchanges to rpc pattern
  • Loading branch information
mosquito authored Sep 20, 2023
2 parents c8dae1e + 62a3268 commit b8e8eff
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 13 deletions.
55 changes: 42 additions & 13 deletions aio_pika/patterns/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def multiply(*, x, y):
result_queue: AbstractQueue
result_consumer_tag: ConsumerTag
dlx_exchange: AbstractExchange
rpc_exchange: Optional[AbstractExchange]

def __init__(
self, channel: AbstractChannel,
Expand Down Expand Up @@ -115,6 +116,13 @@ def create_future(self) -> Tuple[asyncio.Future, str]:
future.add_done_callback(self.__remove_future(correlation_id))
return future, correlation_id

def _format_routing_key(self, method_name: str) -> str:
return (
f'{self.rpc_exchange.name}::{method_name}'
if self.rpc_exchange
else method_name
)

async def close(self) -> None:
if not hasattr(self, "result_queue"):
log.warning("RPC already closed")
Expand Down Expand Up @@ -142,13 +150,22 @@ async def close(self) -> None:
del self.result_queue
del self.dlx_exchange

if self.rpc_exchange:
del self.rpc_exchange

async def initialize(
self, auto_delete: bool = True,
durable: bool = False, **kwargs: Any,
durable: bool = False, exchange: str = '', **kwargs: Any,
) -> None:
if hasattr(self, "result_queue"):
return

self.rpc_exchange = await self.channel.declare_exchange(
exchange,
type=ExchangeType.DIRECT,
auto_delete=True,
durable=durable) if exchange else None

self.result_queue = await self.channel.declare_queue(
None, auto_delete=auto_delete, durable=durable, **kwargs,
)
Expand Down Expand Up @@ -252,14 +269,16 @@ async def on_result_message(self, message: AbstractIncomingMessage) -> None:
async def on_call_message(
self, method_name: str, message: IncomingMessage,
) -> None:
if method_name not in self.routes:

routing_key = self._format_routing_key(method_name)

if routing_key not in self.routes:
log.warning("Method %r not registered in %r", method_name, self)
return

try:
payload = await self.deserialize_message(message)
func = self.routes[method_name]

func = self.routes[routing_key]
result: Any = await self.execute(func, payload)
message_type = RPCMessageType.RESULT
except Exception as e:
Expand Down Expand Up @@ -376,12 +395,14 @@ async def call(
if expiration is not None:
message.expiration = expiration

log.debug("Publishing calls for %s(%r)", method_name, kwargs)
await self.channel.default_exchange.publish(
message, routing_key=method_name, mandatory=True,
)
routing_key = self._format_routing_key(method_name)

log.debug("Publishing calls for %s(%r)", routing_key, kwargs)
exchange = self.rpc_exchange or self.channel.default_exchange
await exchange.publish(message, routing_key=routing_key,
mandatory=True)

log.debug("Waiting RPC result for %s(%r)", method_name, kwargs)
log.debug("Waiting RPC result for %s(%r)", routing_key, kwargs)
return await future

async def register(
Expand All @@ -405,21 +426,29 @@ async def register(

kwargs["arguments"] = arguments

queue = await self.channel.declare_queue(method_name, **kwargs)
routing_key = self._format_routing_key(method_name)

queue = await self.channel.declare_queue(routing_key, **kwargs)

if self.rpc_exchange:
await queue.bind(
self.rpc_exchange,
routing_key
)

if func in self.consumer_tags:
raise RuntimeError("Function already registered")

if method_name in self.routes:
if routing_key in self.routes:
raise RuntimeError(
"Method name already used for %r" % self.routes[method_name],
"Method name already used for %r" % self.routes[routing_key],
)

self.consumer_tags[func] = await queue.consume(
partial(self.on_call_message, method_name),
)

self.routes[method_name] = func
self.routes[routing_key] = func
self.queues[func] = queue

async def unregister(self, func: CallbackType) -> None:
Expand Down
39 changes: 39 additions & 0 deletions tests/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ async def rpc_func(*, foo, bar):
return {"foo": "bar"}


async def rpc_func2(*, foo, bar):
assert not foo
assert not bar

return {"foo": "bar2"}


class TestCase:
async def test_simple(self, channel: aio_pika.Channel):
rpc = await RPC.create(channel, auto_delete=True)
Expand Down Expand Up @@ -238,3 +245,35 @@ async def inner():

with pytest.raises(TypeError):
await rpc.call("test.not-serializable")

async def test_custom_exchange(self, channel: aio_pika.Channel):
rpc_ex1 = await RPC.create(channel, auto_delete=True, exchange='ex1')
rpc_ex2 = await RPC.create(channel, auto_delete=True, exchange='ex2')
rpc_default = await RPC.create(channel, auto_delete=True)

await rpc_ex1.register("test.rpc", rpc_func, auto_delete=True)
result = await rpc_ex1.proxy.test.rpc(foo=None, bar=None)
assert result == {"foo": "bar"}

with pytest.raises(MessageProcessError):
await rpc_ex2.proxy.test.rpc(foo=None, bar=None)

await rpc_ex2.register("test.rpc", rpc_func2, auto_delete=True)
result = await rpc_ex2.proxy.test.rpc(foo=None, bar=None)
assert result == {"foo": "bar2"}

with pytest.raises(MessageProcessError):
await rpc_default.proxy.test.rpc(foo=None, bar=None)

await rpc_default.register("test.rpc", rpc_func, auto_delete=True)
result = await rpc_default.proxy.test.rpc(foo=None, bar=None)
assert result == {"foo": "bar"}

await rpc_ex1.unregister(rpc_func)
await rpc_ex1.close()

await rpc_ex2.unregister(rpc_func2)
await rpc_ex2.close()

await rpc_default.unregister(rpc_func)
await rpc_default.close()

0 comments on commit b8e8eff

Please sign in to comment.