Skip to content

Commit

Permalink
Add global timeout for all operations within connection
Browse files Browse the repository at this point in the history
  • Loading branch information
decaz committed Oct 3, 2019
1 parent 4f0a301 commit d36ec87
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 48 deletions.
18 changes: 12 additions & 6 deletions aio_pika/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def channel(self) -> aiormq.Channel:
def number(self):
return self.channel.number if self._channel else None

def _get_operation_timeout(self, timeout: TimeoutType = None):
if timeout is not None:
return timeout
return self._connection.operation_timeout

def __str__(self):
return "{0}".format(
self.number or "Not initialized channel"
Expand Down Expand Up @@ -163,7 +168,8 @@ async def initialize(self, timeout: TimeoutType = None) -> None:
raise RuntimeError("Can't initialize channel")

self._channel = await asyncio.wait_for(
self._create_channel(), timeout=timeout
self._create_channel(),
timeout=self._get_operation_timeout(timeout)
)

self._delivery_tag = 0
Expand Down Expand Up @@ -247,7 +253,7 @@ async def declare_queue(
"""

queue = self.QUEUE_CLASS(
connection=self,
connection=self._connection,
channel=self.channel,
name=name,
durable=durable,
Expand All @@ -271,7 +277,7 @@ async def set_qos(
prefetch_count=prefetch_count,
prefetch_size=prefetch_size
),
timeout=timeout
timeout=self._get_operation_timeout(timeout)
)

async def queue_delete(
Expand All @@ -286,7 +292,7 @@ async def queue_delete(
if_empty=if_empty,
nowait=nowait,
),
timeout=timeout
timeout=self._get_operation_timeout(timeout)
)

async def exchange_delete(
Expand All @@ -300,15 +306,15 @@ async def exchange_delete(
if_unused=if_unused,
nowait=nowait,
),
timeout=timeout
timeout=self._get_operation_timeout(timeout)
)

def transaction(self) -> Transaction:
if self._publisher_confirms:
raise RuntimeError("Cannot create transaction when publisher "
"confirms are enabled")

return Transaction(self._channel)
return Transaction(connection=self._connection, channel=self._channel)

async def flow(self, active: bool = True) -> aiormq.spec.Channel.FlowOk:
return await self.channel.flow(active=active)
Expand Down
12 changes: 10 additions & 2 deletions aio_pika/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ def _parse_kwargs(cls, kwargs):
result[key] = parser(kwargs.get(key, default))
return result

def __init__(self, url, loop=None, **kwargs):
def __init__(self, url, operation_timeout: TimeoutType = None, loop=None,
**kwargs):
self.loop = loop or asyncio.get_event_loop()
self.url = URL(url)
self.operation_timeout = operation_timeout

self.kwargs = self._parse_kwargs(kwargs or self.url.query)

Expand Down Expand Up @@ -217,6 +219,7 @@ async def connect(
login: str = 'guest', password: str = 'guest', virtualhost: str = '/',
ssl: bool = False, loop: asyncio.AbstractEventLoop = None,
ssl_options: dict = None, timeout: TimeoutType = None,
operation_timeout: TimeoutType = None,
connection_class: Type[ConnectionType] = Connection, **kwargs
) -> ConnectionType:

Expand Down Expand Up @@ -267,6 +270,7 @@ async def main():
:param ssl: use SSL for connection. Should be used with addition kwargs.
:param ssl_options: A dict of values for the SSL connection.
:param timeout: connection timeout in seconds
:param operation_timeout: execution timeout in seconds
:param loop:
Event loop (:func:`asyncio.get_event_loop()` when :class:`None`)
:param connection_class: Factory of a new connection
Expand Down Expand Up @@ -294,7 +298,11 @@ async def main():
query=kw
)

connection = connection_class(url, loop=loop)
connection = connection_class(
url,
operation_timeout=operation_timeout,
loop=loop
)
await connection.connect(timeout=timeout)
return connection

Expand Down
16 changes: 11 additions & 5 deletions aio_pika/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, connection, channel: aiormq.Channel, name: str,
if not arguments:
arguments = {}

self._connection = connection
self._channel = channel
self.__type = type.value
self.name = name
Expand All @@ -52,6 +53,11 @@ def channel(self) -> aiormq.Channel:

return self._channel

def _get_operation_timeout(self, timeout: TimeoutType = None):
if timeout is not None:
return timeout
return self._connection.operation_timeout

def __str__(self):
return self.name

Expand All @@ -71,7 +77,7 @@ async def declare(
internal=self.internal,
passive=self.passive,
arguments=self.arguments,
), timeout=timeout)
), timeout=self._get_operation_timeout(timeout))

@staticmethod
def _get_exchange_name(exchange: ExchangeType_):
Expand Down Expand Up @@ -133,7 +139,7 @@ async def bind(
destination=self.name,
routing_key=routing_key,
source=self._get_exchange_name(exchange),
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
)

async def unbind(
Expand Down Expand Up @@ -163,7 +169,7 @@ async def unbind(
destination=self.name,
routing_key=routing_key,
source=self._get_exchange_name(exchange),
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
)

async def publish(
Expand Down Expand Up @@ -197,7 +203,7 @@ async def publish(
properties=message.properties,
mandatory=mandatory,
immediate=immediate
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
)

async def delete(
Expand All @@ -213,7 +219,7 @@ async def delete(
log.info("Deleting %r", self)
return await asyncio.wait_for(
self.channel.exchange_delete(self.name, if_unused=if_unused),
timeout=timeout
timeout=self._get_operation_timeout(timeout)
)


Expand Down
47 changes: 29 additions & 18 deletions aio_pika/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .exceptions import QueueEmpty
from .exchange import Exchange
from aio_pika.types import ExchangeType as ExchangeType_
from .types import ExchangeType as ExchangeType_, TimeoutType
from .message import IncomingMessage
from .tools import create_task, shield

Expand All @@ -36,6 +36,7 @@ def __init__(self, connection, channel: aiormq.Channel, name,

self.loop = connection.loop

self._connection = connection
self._channel = channel
self.name = name or ''
self.durable = durable
Expand All @@ -52,6 +53,11 @@ def channel(self) -> aiormq.Channel:
raise RuntimeError("Channel not opened")
return self._channel

def _get_operation_timeout(self, timeout: TimeoutType = None):
if timeout is not None:
return timeout
return self._connection.operation_timeout

def __str__(self):
return "%s" % self.name

Expand All @@ -70,7 +76,9 @@ def __repr__(self):
self.arguments,
)

async def declare(self, timeout: int=None) -> aiormq.spec.Queue.DeclareOk:
async def declare(
self, timeout: TimeoutType = None
) -> aiormq.spec.Queue.DeclareOk:
""" Declare queue.
:param timeout: execution timeout
Expand All @@ -84,15 +92,15 @@ async def declare(self, timeout: int=None) -> aiormq.spec.Queue.DeclareOk:
queue=self.name, durable=self.durable,
exclusive=self.exclusive, auto_delete=self.auto_delete,
arguments=self.arguments, passive=self.passive,
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
) # type: aiormq.spec.Queue.DeclareOk

self.name = self.declaration_result.queue
return self.declaration_result

async def bind(
self, exchange: ExchangeType_, routing_key: str=None, *,
arguments=None, timeout: int=None
arguments=None, timeout: TimeoutType = None
) -> aiormq.spec.Queue.BindOk:

""" A binding is a relationship between an exchange and a queue.
Expand Down Expand Up @@ -126,12 +134,12 @@ async def bind(
exchange=Exchange._get_exchange_name(exchange),
routing_key=routing_key,
arguments=arguments
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
)

async def unbind(
self, exchange: ExchangeType_, routing_key: str=None,
arguments: dict=None, timeout: int=None
arguments: dict=None, timeout: TimeoutType = None
) -> aiormq.spec.Queue.UnbindOk:

""" Remove binding from exchange for this :class:`Queue` instance
Expand Down Expand Up @@ -159,13 +167,13 @@ async def unbind(
exchange=Exchange._get_exchange_name(exchange),
routing_key=routing_key,
arguments=arguments
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
)

async def consume(
self, callback: Callable[[IncomingMessage], Any], no_ack: bool = False,
exclusive: bool = False, arguments: dict = None,
consumer_tag=None, timeout=None
consumer_tag=None, timeout: TimeoutType = None
) -> ConsumerTag:

""" Start to consuming the :class:`Queue`.
Expand Down Expand Up @@ -203,11 +211,13 @@ async def consume(
arguments=arguments,
consumer_tag=consumer_tag,
),
timeout=timeout
timeout=self._get_operation_timeout(timeout)
)).consumer_tag

async def cancel(self, consumer_tag: ConsumerTag, timeout=None,
nowait: bool=False) -> aiormq.spec.Basic.CancelOk:
async def cancel(
self, consumer_tag: ConsumerTag, timeout: TimeoutType = None,
nowait: bool=False
) -> aiormq.spec.Basic.CancelOk:
""" This method cancels a consumer. This does not affect already
delivered messages, but it does mean the server will not send any more
messages for that consumer. The client may receive an arbitrary number
Expand All @@ -230,7 +240,7 @@ async def cancel(self, consumer_tag: ConsumerTag, timeout=None,
consumer_tag=consumer_tag,
nowait=nowait
),
timeout=timeout
timeout=self._get_operation_timeout(timeout)
)

async def get(
Expand All @@ -249,7 +259,7 @@ async def get(

msg = await asyncio.wait_for(self.channel.basic_get(
self.name, no_ack=no_ack
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
) # type: Optional[DeliveredMessage]

if msg is None:
Expand All @@ -260,7 +270,7 @@ async def get(
return IncomingMessage(msg, no_ack=no_ack)

async def purge(
self, no_wait=False, timeout=None
self, no_wait=False, timeout: TimeoutType = None
) -> aiormq.spec.Queue.PurgeOk:
""" Purge all messages from the queue.
Expand All @@ -275,11 +285,12 @@ async def purge(
self.channel.queue_purge(
self.name,
nowait=no_wait,
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
)

async def delete(self, *, if_unused=True, if_empty=True,
timeout=None) -> aiormq.spec.Queue.DeclareOk:
async def delete(
self, *, if_unused=True, if_empty=True, timeout: TimeoutType = None
) -> aiormq.spec.Queue.DeclareOk:

""" Delete the queue.
Expand All @@ -294,7 +305,7 @@ async def delete(self, *, if_unused=True, if_empty=True,
return await asyncio.wait_for(
self.channel.queue_delete(
self.name, if_unused=if_unused, if_empty=if_empty
), timeout=timeout
), timeout=self._get_operation_timeout(timeout)
)

def __aiter__(self) -> 'QueueIterator':
Expand Down
5 changes: 4 additions & 1 deletion aio_pika/robust_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ async def connect_robust(
login: str = 'guest', password: str = 'guest', virtualhost: str = '/',
ssl: bool = False, loop: asyncio.AbstractEventLoop = None,
ssl_options: dict = None, timeout: TimeoutType = None,
operation_timeout: TimeoutType = None,
connection_class: Type[ConnectionType] = RobustConnection, **kwargs
) -> ConnectionType:

Expand Down Expand Up @@ -209,6 +210,7 @@ async def main():
:param ssl: use SSL for connection. Should be used with addition kwargs.
:param ssl_options: A dict of values for the SSL connection.
:param timeout: connection timeout in seconds
:param operation_timeout: execution timeout in seconds
:param loop:
Event loop (:func:`asyncio.get_event_loop()` when :class:`None`)
:param connection_class: Factory of a new connection
Expand All @@ -224,7 +226,8 @@ async def main():
url=url, host=host, port=port, login=login,
password=password, virtualhost=virtualhost, ssl=ssl,
loop=loop, connection_class=connection_class,
ssl_options=ssl_options, timeout=timeout, **kwargs
ssl_options=ssl_options, timeout=timeout,
operation_timeout=operation_timeout, **kwargs
)
)

Expand Down
5 changes: 3 additions & 2 deletions aio_pika/robust_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .exchange import Exchange, ExchangeType
from .channel import Channel
from .types import TimeoutType


log = getLogger(__name__)
Expand Down Expand Up @@ -42,7 +43,7 @@ async def on_reconnect(self, channel: Channel):
await self.bind(exchange, **kwargs)

async def bind(self, exchange, routing_key: str='', *,
arguments=None, timeout: int=None):
arguments=None, timeout: TimeoutType = None):
result = await super().bind(
exchange, routing_key=routing_key,
arguments=arguments, timeout=timeout
Expand All @@ -55,7 +56,7 @@ async def bind(self, exchange, routing_key: str='', *,
return result

async def unbind(self, exchange, routing_key: str = '',
arguments: dict=None, timeout: int=None):
arguments: dict=None, timeout: TimeoutType = None):

result = await super().unbind(exchange, routing_key,
arguments=arguments, timeout=timeout)
Expand Down
Loading

0 comments on commit d36ec87

Please sign in to comment.