Skip to content

Commit

Permalink
Ensure async and/or versions called when necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
hasier committed Mar 18, 2024
1 parent c97c138 commit d0f11cb
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 11 deletions.
30 changes: 25 additions & 5 deletions tenacity/asyncio/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,26 @@ class async_retry_base(retry_base):
async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override]
pass

def __and__(self, other: "typing.Union[retry_base, async_retry_base]") -> "retry_all": # type: ignore[override]
def __and__( # type: ignore[override]
self, other: "typing.Union[retry_base, async_retry_base]"
) -> "retry_all":
return retry_all(self, other)

def __or__(self, other: "typing.Union[retry_base, async_retry_base]") -> "retry_any": # type: ignore[override]
def __rand__( # type: ignore[misc,override]
self, other: "typing.Union[retry_base, async_retry_base]"
) -> "retry_all":
return retry_all(other, self)

def __or__( # type: ignore[override]
self, other: "typing.Union[retry_base, async_retry_base]"
) -> "retry_any":
return retry_any(self, other)

def __ror__( # type: ignore[misc,override]
self, other: "typing.Union[retry_base, async_retry_base]"
) -> "retry_any":
return retry_any(other, self)


class async_predicate_mixin:
async def __call__(self, retry_state: "RetryCallState") -> bool:
Expand All @@ -48,20 +62,26 @@ async def __call__(self, retry_state: "RetryCallState") -> bool:
return typing.cast(bool, result)


RetryBaseT = typing.Union[async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]]
RetryBaseT = typing.Union[
async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]
]


class retry_if_exception(async_predicate_mixin, _retry_if_exception, async_retry_base): # type: ignore[misc]
"""Retry strategy that retries if an exception verifies a predicate."""

def __init__(self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]]) -> None:
def __init__(
self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]]
) -> None:
super().__init__(predicate) # type: ignore[arg-type]


class retry_if_result(async_predicate_mixin, _retry_if_result, async_retry_base): # type: ignore[misc]
"""Retries if the result verifies a predicate."""

def __init__(self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]) -> None:
def __init__(
self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]
) -> None:
super().__init__(predicate) # type: ignore[arg-type]


Expand Down
10 changes: 8 additions & 2 deletions tenacity/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@ def __call__(self, retry_state: "RetryCallState") -> bool:
pass

def __and__(self, other: "retry_base") -> "retry_all":
return retry_all(self, other)
return other.__rand__(self)

def __rand__(self, other: "retry_base") -> "retry_all":
return retry_all(other, self)

def __or__(self, other: "retry_base") -> "retry_any":
return retry_any(self, other)
return other.__ror__(self)

def __ror__(self, other: "retry_base") -> "retry_any":
return retry_any(other, self)


RetryBaseT = typing.Union[retry_base, typing.Callable[["RetryCallState"], bool]]
Expand Down
115 changes: 111 additions & 4 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import tenacity
from tenacity import AsyncRetrying, RetryError
from tenacity import asyncio as tasyncio
from tenacity import retry, retry_if_result, stop_after_attempt
from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt
from tenacity.wait import wait_fixed

from .test_tenacity import NoIOErrorAfterCount, current_time_ms
Expand Down Expand Up @@ -202,6 +202,59 @@ def lt_3(x: float) -> bool:

self.assertEqual(3, result)

@asynctest
async def test_retry_with_async_result(self):
async def test():
attempts = 0

async def lt_3(x: float) -> bool:
return x < 3

async for attempt in tasyncio.AsyncRetrying(
retry=tasyncio.retry_if_result(lt_3)
):
with attempt:
attempts += 1

assert attempt.retry_state.outcome # help mypy
if not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(attempts)

return attempts

result = await test()

self.assertEqual(3, result)

@asynctest
async def test_retry_with_async_exc(self):
async def test():
attempts = 0

class CustomException(Exception):
pass

async def is_exc(e: BaseException) -> bool:
return isinstance(e, CustomException)

async for attempt in tasyncio.AsyncRetrying(
retry=tasyncio.retry_if_exception(is_exc)
):
with attempt:
attempts += 1
if attempts < 3:
raise CustomException()

assert attempt.retry_state.outcome # help mypy
if not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(attempts)

return attempts

result = await test()

self.assertEqual(3, result)

@asynctest
async def test_retry_with_async_result_or(self):
async def test():
Expand All @@ -213,14 +266,45 @@ async def lt_3(x: float) -> bool:
class CustomException(Exception):
pass

def is_exc(e: BaseException) -> bool:
return isinstance(e, CustomException)

retry_strategy = tasyncio.retry_if_result(lt_3) | retry_if_exception(is_exc)
async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy):
with attempt:
attempts += 1
if 2 < attempts < 4:
raise CustomException()

assert attempt.retry_state.outcome # help mypy
if not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(attempts)

return attempts

result = await test()

self.assertEqual(4, result)

@asynctest
async def test_retry_with_async_result_ror(self):
async def test():
attempts = 0

def lt_3(x: float) -> bool:
return x < 3

class CustomException(Exception):
pass

async def is_exc(e: BaseException) -> bool:
return isinstance(e, CustomException)

retry_strategy = tasyncio.retry_if_result(lt_3) | tasyncio.retry_if_exception(is_exc)
retry_strategy = retry_if_result(lt_3) | tasyncio.retry_if_exception(is_exc)
async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy):
with attempt:
attempts += 1
if 1 < attempts < 3:
if 2 < attempts < 4:
raise CustomException()

assert attempt.retry_state.outcome # help mypy
Expand All @@ -231,7 +315,7 @@ async def is_exc(e: BaseException) -> bool:

result = await test()

self.assertEqual(3, result)
self.assertEqual(4, result)

@asynctest
async def test_retry_with_async_result_and(self):
Expand All @@ -256,6 +340,29 @@ def gt_0(x: float) -> bool:

self.assertEqual(3, result)

@asynctest
async def test_retry_with_async_result_rand(self):
async def test():
attempts = 0

async def lt_3(x: float) -> bool:
return x < 3

def gt_0(x: float) -> bool:
return x > 0

retry_strategy = retry_if_result(gt_0) & tasyncio.retry_if_result(lt_3)
async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy):
with attempt:
attempts += 1
attempt.retry_state.set_result(attempts)

return attempts

result = await test()

self.assertEqual(3, result)

@asynctest
async def test_async_retying_iterator(self):
thing = NoIOErrorAfterCount(5)
Expand Down

0 comments on commit d0f11cb

Please sign in to comment.