diff --git a/releasenotes/notes/fix-local-context-overwrite-94190ba06a481631.yaml b/releasenotes/notes/fix-local-context-overwrite-94190ba06a481631.yaml new file mode 100644 index 0000000..ff2ba7e --- /dev/null +++ b/releasenotes/notes/fix-local-context-overwrite-94190ba06a481631.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Avoid overwriting local contexts when applying the retry decorator. diff --git a/tenacity/__init__.py b/tenacity/__init__.py index 7de36d4..06251ed 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -329,13 +329,19 @@ def wraps(self, f: WrappedFn) -> WrappedFn: f, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__") ) def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any: - return self(f, *args, **kw) + # Always create a copy to prevent overwriting the local contexts when + # calling the same wrapped functions multiple times in the same stack + copy = self.copy() + wrapped_f.statistics = copy.statistics # type: ignore[attr-defined] + return copy(f, *args, **kw) def retry_with(*args: t.Any, **kwargs: t.Any) -> WrappedFn: return self.copy(*args, **kwargs).wraps(f) - wrapped_f.retry = self # type: ignore[attr-defined] + # Preserve attributes + wrapped_f.retry = wrapped_f # type: ignore[attr-defined] wrapped_f.retry_with = retry_with # type: ignore[attr-defined] + wrapped_f.statistics = {} # type: ignore[attr-defined] return wrapped_f # type: ignore[return-value] diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py index 6d63ebc..38b76c7 100644 --- a/tenacity/asyncio/__init__.py +++ b/tenacity/asyncio/__init__.py @@ -175,18 +175,23 @@ async def __anext__(self) -> AttemptManager: raise StopAsyncIteration def wraps(self, fn: WrappedFn) -> WrappedFn: - fn = super().wraps(fn) + wrapped = super().wraps(fn) # Ensure wrapper is recognized as a coroutine function. @functools.wraps( fn, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__") ) async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: - return await fn(*args, **kwargs) + # Always create a copy to prevent overwriting the local contexts when + # calling the same wrapped functions multiple times in the same stack + copy = self.copy() + async_wrapped.statistics = copy.statistics # type: ignore[attr-defined] + return await copy(fn, *args, **kwargs) # Preserve attributes - async_wrapped.retry = fn.retry # type: ignore[attr-defined] - async_wrapped.retry_with = fn.retry_with # type: ignore[attr-defined] + async_wrapped.retry = async_wrapped # type: ignore[attr-defined] + async_wrapped.retry_with = wrapped.retry_with # type: ignore[attr-defined] + async_wrapped.statistics = {} # type: ignore[attr-defined] return async_wrapped # type: ignore[return-value] diff --git a/tests/test_issue_478.py b/tests/test_issue_478.py new file mode 100644 index 0000000..7489ad7 --- /dev/null +++ b/tests/test_issue_478.py @@ -0,0 +1,118 @@ +import asyncio +import typing +import unittest + +from functools import wraps + +from tenacity import RetryCallState, retry + + +def asynctest( + callable_: typing.Callable[..., typing.Any], +) -> typing.Callable[..., typing.Any]: + @wraps(callable_) + def wrapper(*a: typing.Any, **kw: typing.Any) -> typing.Any: + loop = asyncio.get_event_loop() + return loop.run_until_complete(callable_(*a, **kw)) + + return wrapper + + +MAX_RETRY_FIX_ATTEMPTS = 2 + + +class TestIssue478(unittest.TestCase): + def test_issue(self) -> None: + results = [] + + def do_retry(retry_state: RetryCallState) -> bool: + outcome = retry_state.outcome + assert outcome + ex = outcome.exception() + _subject_: str = retry_state.args[0] + + if _subject_ == "Fix": # no retry on fix failure + return False + + if retry_state.attempt_number >= MAX_RETRY_FIX_ATTEMPTS: + return False + + if ex: + do_fix_work() + return True + + return False + + @retry(reraise=True, retry=do_retry) + def _do_work(subject: str) -> None: + if subject == "Error": + results.append(f"{subject} is not working") + raise Exception(f"{subject} is not working") + results.append(f"{subject} is working") + + def do_any_work(subject: str) -> None: + _do_work(subject) + + def do_fix_work() -> None: + _do_work("Fix") + + try: + do_any_work("Error") + except Exception as exc: + assert str(exc) == "Error is not working" + else: + assert False, "No exception caught" + + assert results == [ + "Error is not working", + "Fix is working", + "Error is not working", + ] + + @asynctest + async def test_async(self) -> None: + results = [] + + async def do_retry(retry_state: RetryCallState) -> bool: + outcome = retry_state.outcome + assert outcome + ex = outcome.exception() + _subject_: str = retry_state.args[0] + + if _subject_ == "Fix": # no retry on fix failure + return False + + if retry_state.attempt_number >= MAX_RETRY_FIX_ATTEMPTS: + return False + + if ex: + await do_fix_work() + return True + + return False + + @retry(reraise=True, retry=do_retry) + async def _do_work(subject: str) -> None: + if subject == "Error": + results.append(f"{subject} is not working") + raise Exception(f"{subject} is not working") + results.append(f"{subject} is working") + + async def do_any_work(subject: str) -> None: + await _do_work(subject) + + async def do_fix_work() -> None: + await _do_work("Fix") + + try: + await do_any_work("Error") + except Exception as exc: + assert str(exc) == "Error is not working" + else: + assert False, "No exception caught" + + assert results == [ + "Error is not working", + "Fix is working", + "Error is not working", + ]