Skip to content

Commit

Permalink
More mature version of utiloori.func.retry_exponential_backoff(), and…
Browse files Browse the repository at this point in the history
… test cases.
  • Loading branch information
uogbuji committed Nov 23, 2024
1 parent 4475b94 commit 7b85ce8
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 36 deletions.
171 changes: 135 additions & 36 deletions pylib/func.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,154 @@
import time
# SPDX-FileCopyrightText: 2023-present Oori Data <[email protected]>
# SPDX-License-Identifier: Apache-2.0
# utiloori.func

import asyncio
import random
from functools import wraps
from typing import Callable, Any, Optional, Tuple, Union, Type


def retry_exponential_backoff(
func,
def retry_with_exponential_backoff(
*decorator_args,
initial_delay: float = 1,
exponential_base: float = 2,
jitter: bool = True,
max_retries: int = 10,
errors: tuple = None,
retry_if_retval = None,
log_errors = None
max_delay: Optional[float] = None,
errors: Tuple[Type[Exception], ...] = (ConnectionError, TimeoutError),
retry_if_retval: Optional[Union[Callable[[Any], bool], Tuple[Any, ...]]] = None,
logger: Optional[Callable[[Exception], None]] = None
):
'''
Retry function call with exponential backoff
Decorates a function with exponential backoff retry mechanism.
Supports both synchronous and asynchronous functions. Retries the function
call with increasing delays between attempts, with optional jitter to prevent
[thundering herd problem](https://en.wikipedia.org/wiki/Thundering_herd_problem).
Args:
func (Callable): The function to be retried.
initial_delay (float, optional): Initial delay between retries. Defaults to 1 second.
exponential_base (float, optional): Base for exponential backoff. Defaults to 2.
jitter (bool, optional): Add randomness to delay to prevent synchronization. Defaults to True.
max_retries (int, optional): Maximum number of retry attempts. Defaults to 10.
max_delay (float, optional): Maximum delay between retries. Defaults to None.
errors (Tuple[Type[Exception], ...], optional): Exceptions that trigger a retry.
Defaults to (ConnectionError, TimeoutError).
retry_if_retval (Union[Callable[[Any], bool], Tuple[Any, ...]], optional):
Condition to retry based on return value. Can be a predicate function
or a tuple of values to retry.
logger (Callable[[Exception], None], optional): Optional logging function
for retry-related errors.
Returns:
The result of the function call after successful execution or raising
a RuntimeError if max retries are exceeded.
Raises:
RuntimeError: If maximum number of retries is exceeded.
# Sync function example
@retry_with_exponential_backoff
def sync_fetch():
# sync code here
# Async function example
@retry_with_exponential_backoff
async def async_fetch():
# async code here
# Passing custom parameters
result = sync_fetch(initial_delay=2, max_retries=5)
# Or with async function
result = await async_fetch(initial_delay=2, max_retries=5)
'''
# Allow decorator to be used with or without arguments
if len(decorator_args) == 1 and callable(decorator_args[0]):
func = decorator_args[0]
return _retry_wrapper(
func,
initial_delay,
exponential_base,
jitter,
max_retries,
max_delay,
errors,
retry_if_retval,
logger
)

def decorator(func):
return _retry_wrapper(
func,
initial_delay,
exponential_base,
jitter,
max_retries,
max_delay,
errors,
retry_if_retval,
logger
)

return decorator

def _retry_wrapper(
func: Callable,
initial_delay: float,
exponential_base: float,
jitter: bool,
max_retries: int,
max_delay: Optional[float],
errors: Tuple[Type[Exception], ...],
retry_if_retval: Optional[Union[Callable[[Any], bool], Tuple[Any, ...]]],
logger: Optional[Callable[[Exception], None]]
):
@wraps(func)
def wrapper(*args, **kwargs):
# Initialize variables
num_retries = 0
delay = initial_delay

# Loop until a successful response or max_retries is hit or an exception is raised
while True:
try:
retval = func(*args, **kwargs)
if retval in retry_if_retval:
err = ValueError(f'Unwanted return value {retval}')
else:
return retval

# Retry on specific errors
except errors as e:
err = e
async def async_wrapper():
num_retries = 0
delay = initial_delay

# Raise exceptions for any errors not specified
except Exception as e:
raise e
while True:
try:
# Check if the function is async
if asyncio.iscoroutinefunction(func):
retval = await func(*args, **kwargs)
else:
retval = func(*args, **kwargs)

# More flexible return value checking
if retry_if_retval:
should_retry = (
(callable(retry_if_retval) and retry_if_retval(retval)) or
(isinstance(retry_if_retval, tuple) and retval in retry_if_retval)
)
if should_retry:
raise ValueError(f'Unwanted return value {retval}')

return retval

if log_errors:
log_errors(err)
except errors as e:
if logger:
logger(e)

num_retries += 1
num_retries += 1
if num_retries > max_retries:
raise RuntimeError(f'Maximum retries ({max_retries}) exceeded.')

if num_retries > max_retries:
# Max retries has been reached
raise RuntimeError(f'Maximum number of retries ({max_retries}) exceeded.')
# Add max_delay constraint
delay *= exponential_base * (1 + jitter * random.random())
if max_delay:
delay = min(delay, max_delay)

# Backoff: Increment delay
delay *= exponential_base * (1 + jitter * random.random())
await asyncio.sleep(delay)

# Sleep for the delay
time.sleep(delay)
# Return a coroutine if the function is async, otherwise run the sync version
if asyncio.iscoroutinefunction(func):
return async_wrapper()
else:
return asyncio.run(async_wrapper())

return wrapper
3 changes: 3 additions & 0 deletions test/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
markers =
asyncio: asyncio mark
116 changes: 116 additions & 0 deletions test/test_func_retry_exp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import pytest
import random
from unittest.mock import Mock # , patch

from utiloori.func import retry_with_exponential_backoff

# Adding initial_delay=0.01 & exponential_base=1 throughout is to make sire it's fast, as a unit test should be


def test_sync_retry_success():
mock_counter = Mock(side_effect=[
ConnectionError("First failure"),
ConnectionError("Second failure"),
"Success"
])

# No, don't do this, because we want determinism for unit tests
# @retry_with_exponential_backoff(exponential_base=1.1)
# def flaky_sync_func():
# if random.random() < 0.7:
# raise ConnectionError("Simulated connection error")
# return "Success"

@retry_with_exponential_backoff(initial_delay=0.01, exponential_base=1)
def flaky_sync_func():
return mock_counter()

result = flaky_sync_func()
assert result == "Success"
assert mock_counter.call_count == 3


@pytest.mark.asyncio
async def test_async_retry_success():
mock_counter = Mock(side_effect=[
ConnectionError("First failure"),
"Success"
])

@retry_with_exponential_backoff(initial_delay=0.01, exponential_base=1)
async def flaky_async_func():
return mock_counter()

result = await flaky_async_func()
assert result == "Success"
assert mock_counter.call_count == 2


def test_max_retries_exceeded():
mock_counter = Mock(side_effect=ConnectionError("Always failing"))

@retry_with_exponential_backoff(max_retries=2, initial_delay=0.01, exponential_base=1)
def always_fail():
return mock_counter()

with pytest.raises(RuntimeError, match="Maximum retries"):
always_fail()
assert mock_counter.call_count == 3 # Initial try + 2 retries


def test_retry_on_return_value():
mock_counter = Mock(side_effect=[None, False, "Success"], errors=(ValueError,))

@retry_with_exponential_backoff(
retry_if_retval=(None, False),
errors=(ValueError,),
initial_delay=0.01,
exponential_base=1
)
def flaky_return_func():
return mock_counter()

result = flaky_return_func()
assert result == "Success"
assert mock_counter.call_count == 3


@pytest.mark.asyncio
async def test_custom_retry_params():
mock_counter = Mock(side_effect=[
ValueError("First failure"),
ValueError("Second failure"),
"Success"
])

@retry_with_exponential_backoff(
initial_delay=0.01,
max_retries=5,
exponential_base=1,
errors=(ValueError,)
)
async def custom_retry_func():
return mock_counter()

result = await custom_retry_func()
assert result == "Success"
assert mock_counter.call_count == 3


def test_different_exceptions():
mock_counter = Mock(side_effect=[
ConnectionError("Retry this"),
ValueError("Don't retry this")
])

@retry_with_exponential_backoff(
initial_delay=0.01,
errors=(ConnectionError,)
)
def specific_error_func():
return mock_counter()

with pytest.raises(ValueError):
specific_error_func()
assert mock_counter.call_count == 2

0 comments on commit 7b85ce8

Please sign in to comment.