-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
More mature version of utiloori.func.retry_exponential_backoff(), and…
… test cases.
- Loading branch information
Showing
3 changed files
with
254 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,154 @@ | ||
import time | ||
# SPDX-FileCopyrightText: 2023-present Oori Data <[email protected]> | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# utiloori.func | ||
|
||
import asyncio | ||
import random | ||
from functools import wraps | ||
from typing import Callable, Any, Optional, Tuple, Union, Type | ||
|
||
|
||
def retry_exponential_backoff( | ||
func, | ||
def retry_with_exponential_backoff( | ||
*decorator_args, | ||
initial_delay: float = 1, | ||
exponential_base: float = 2, | ||
jitter: bool = True, | ||
max_retries: int = 10, | ||
errors: tuple = None, | ||
retry_if_retval = None, | ||
log_errors = None | ||
max_delay: Optional[float] = None, | ||
errors: Tuple[Type[Exception], ...] = (ConnectionError, TimeoutError), | ||
retry_if_retval: Optional[Union[Callable[[Any], bool], Tuple[Any, ...]]] = None, | ||
logger: Optional[Callable[[Exception], None]] = None | ||
): | ||
''' | ||
Retry function call with exponential backoff | ||
Decorates a function with exponential backoff retry mechanism. | ||
Supports both synchronous and asynchronous functions. Retries the function | ||
call with increasing delays between attempts, with optional jitter to prevent | ||
[thundering herd problem](https://en.wikipedia.org/wiki/Thundering_herd_problem). | ||
Args: | ||
func (Callable): The function to be retried. | ||
initial_delay (float, optional): Initial delay between retries. Defaults to 1 second. | ||
exponential_base (float, optional): Base for exponential backoff. Defaults to 2. | ||
jitter (bool, optional): Add randomness to delay to prevent synchronization. Defaults to True. | ||
max_retries (int, optional): Maximum number of retry attempts. Defaults to 10. | ||
max_delay (float, optional): Maximum delay between retries. Defaults to None. | ||
errors (Tuple[Type[Exception], ...], optional): Exceptions that trigger a retry. | ||
Defaults to (ConnectionError, TimeoutError). | ||
retry_if_retval (Union[Callable[[Any], bool], Tuple[Any, ...]], optional): | ||
Condition to retry based on return value. Can be a predicate function | ||
or a tuple of values to retry. | ||
logger (Callable[[Exception], None], optional): Optional logging function | ||
for retry-related errors. | ||
Returns: | ||
The result of the function call after successful execution or raising | ||
a RuntimeError if max retries are exceeded. | ||
Raises: | ||
RuntimeError: If maximum number of retries is exceeded. | ||
# Sync function example | ||
@retry_with_exponential_backoff | ||
def sync_fetch(): | ||
# sync code here | ||
# Async function example | ||
@retry_with_exponential_backoff | ||
async def async_fetch(): | ||
# async code here | ||
# Passing custom parameters | ||
result = sync_fetch(initial_delay=2, max_retries=5) | ||
# Or with async function | ||
result = await async_fetch(initial_delay=2, max_retries=5) | ||
''' | ||
# Allow decorator to be used with or without arguments | ||
if len(decorator_args) == 1 and callable(decorator_args[0]): | ||
func = decorator_args[0] | ||
return _retry_wrapper( | ||
func, | ||
initial_delay, | ||
exponential_base, | ||
jitter, | ||
max_retries, | ||
max_delay, | ||
errors, | ||
retry_if_retval, | ||
logger | ||
) | ||
|
||
def decorator(func): | ||
return _retry_wrapper( | ||
func, | ||
initial_delay, | ||
exponential_base, | ||
jitter, | ||
max_retries, | ||
max_delay, | ||
errors, | ||
retry_if_retval, | ||
logger | ||
) | ||
|
||
return decorator | ||
|
||
def _retry_wrapper( | ||
func: Callable, | ||
initial_delay: float, | ||
exponential_base: float, | ||
jitter: bool, | ||
max_retries: int, | ||
max_delay: Optional[float], | ||
errors: Tuple[Type[Exception], ...], | ||
retry_if_retval: Optional[Union[Callable[[Any], bool], Tuple[Any, ...]]], | ||
logger: Optional[Callable[[Exception], None]] | ||
): | ||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
# Initialize variables | ||
num_retries = 0 | ||
delay = initial_delay | ||
|
||
# Loop until a successful response or max_retries is hit or an exception is raised | ||
while True: | ||
try: | ||
retval = func(*args, **kwargs) | ||
if retval in retry_if_retval: | ||
err = ValueError(f'Unwanted return value {retval}') | ||
else: | ||
return retval | ||
|
||
# Retry on specific errors | ||
except errors as e: | ||
err = e | ||
async def async_wrapper(): | ||
num_retries = 0 | ||
delay = initial_delay | ||
|
||
# Raise exceptions for any errors not specified | ||
except Exception as e: | ||
raise e | ||
while True: | ||
try: | ||
# Check if the function is async | ||
if asyncio.iscoroutinefunction(func): | ||
retval = await func(*args, **kwargs) | ||
else: | ||
retval = func(*args, **kwargs) | ||
|
||
# More flexible return value checking | ||
if retry_if_retval: | ||
should_retry = ( | ||
(callable(retry_if_retval) and retry_if_retval(retval)) or | ||
(isinstance(retry_if_retval, tuple) and retval in retry_if_retval) | ||
) | ||
if should_retry: | ||
raise ValueError(f'Unwanted return value {retval}') | ||
|
||
return retval | ||
|
||
if log_errors: | ||
log_errors(err) | ||
except errors as e: | ||
if logger: | ||
logger(e) | ||
|
||
num_retries += 1 | ||
num_retries += 1 | ||
if num_retries > max_retries: | ||
raise RuntimeError(f'Maximum retries ({max_retries}) exceeded.') | ||
|
||
if num_retries > max_retries: | ||
# Max retries has been reached | ||
raise RuntimeError(f'Maximum number of retries ({max_retries}) exceeded.') | ||
# Add max_delay constraint | ||
delay *= exponential_base * (1 + jitter * random.random()) | ||
if max_delay: | ||
delay = min(delay, max_delay) | ||
|
||
# Backoff: Increment delay | ||
delay *= exponential_base * (1 + jitter * random.random()) | ||
await asyncio.sleep(delay) | ||
|
||
# Sleep for the delay | ||
time.sleep(delay) | ||
# Return a coroutine if the function is async, otherwise run the sync version | ||
if asyncio.iscoroutinefunction(func): | ||
return async_wrapper() | ||
else: | ||
return asyncio.run(async_wrapper()) | ||
|
||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[pytest] | ||
markers = | ||
asyncio: asyncio mark |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import pytest | ||
import random | ||
from unittest.mock import Mock # , patch | ||
|
||
from utiloori.func import retry_with_exponential_backoff | ||
|
||
# Adding initial_delay=0.01 & exponential_base=1 throughout is to make sire it's fast, as a unit test should be | ||
|
||
|
||
def test_sync_retry_success(): | ||
mock_counter = Mock(side_effect=[ | ||
ConnectionError("First failure"), | ||
ConnectionError("Second failure"), | ||
"Success" | ||
]) | ||
|
||
# No, don't do this, because we want determinism for unit tests | ||
# @retry_with_exponential_backoff(exponential_base=1.1) | ||
# def flaky_sync_func(): | ||
# if random.random() < 0.7: | ||
# raise ConnectionError("Simulated connection error") | ||
# return "Success" | ||
|
||
@retry_with_exponential_backoff(initial_delay=0.01, exponential_base=1) | ||
def flaky_sync_func(): | ||
return mock_counter() | ||
|
||
result = flaky_sync_func() | ||
assert result == "Success" | ||
assert mock_counter.call_count == 3 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_async_retry_success(): | ||
mock_counter = Mock(side_effect=[ | ||
ConnectionError("First failure"), | ||
"Success" | ||
]) | ||
|
||
@retry_with_exponential_backoff(initial_delay=0.01, exponential_base=1) | ||
async def flaky_async_func(): | ||
return mock_counter() | ||
|
||
result = await flaky_async_func() | ||
assert result == "Success" | ||
assert mock_counter.call_count == 2 | ||
|
||
|
||
def test_max_retries_exceeded(): | ||
mock_counter = Mock(side_effect=ConnectionError("Always failing")) | ||
|
||
@retry_with_exponential_backoff(max_retries=2, initial_delay=0.01, exponential_base=1) | ||
def always_fail(): | ||
return mock_counter() | ||
|
||
with pytest.raises(RuntimeError, match="Maximum retries"): | ||
always_fail() | ||
assert mock_counter.call_count == 3 # Initial try + 2 retries | ||
|
||
|
||
def test_retry_on_return_value(): | ||
mock_counter = Mock(side_effect=[None, False, "Success"], errors=(ValueError,)) | ||
|
||
@retry_with_exponential_backoff( | ||
retry_if_retval=(None, False), | ||
errors=(ValueError,), | ||
initial_delay=0.01, | ||
exponential_base=1 | ||
) | ||
def flaky_return_func(): | ||
return mock_counter() | ||
|
||
result = flaky_return_func() | ||
assert result == "Success" | ||
assert mock_counter.call_count == 3 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_custom_retry_params(): | ||
mock_counter = Mock(side_effect=[ | ||
ValueError("First failure"), | ||
ValueError("Second failure"), | ||
"Success" | ||
]) | ||
|
||
@retry_with_exponential_backoff( | ||
initial_delay=0.01, | ||
max_retries=5, | ||
exponential_base=1, | ||
errors=(ValueError,) | ||
) | ||
async def custom_retry_func(): | ||
return mock_counter() | ||
|
||
result = await custom_retry_func() | ||
assert result == "Success" | ||
assert mock_counter.call_count == 3 | ||
|
||
|
||
def test_different_exceptions(): | ||
mock_counter = Mock(side_effect=[ | ||
ConnectionError("Retry this"), | ||
ValueError("Don't retry this") | ||
]) | ||
|
||
@retry_with_exponential_backoff( | ||
initial_delay=0.01, | ||
errors=(ConnectionError,) | ||
) | ||
def specific_error_func(): | ||
return mock_counter() | ||
|
||
with pytest.raises(ValueError): | ||
specific_error_func() | ||
assert mock_counter.call_count == 2 | ||
|