Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use Contextmanagers to handle StopIteration in generators #12934

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/_pytest/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ def pytest_runtest_logstart(self) -> None:
def pytest_runtest_logreport(self) -> None:
self.log_cli_handler.set_when("logreport")

@contextmanager
def _runtest_for(self, item: nodes.Item, when: str) -> Generator[None]:
"""Implement the internals of the pytest_runtest_xxx() hooks."""
with catching_logs(
Expand All @@ -837,20 +838,23 @@ def pytest_runtest_setup(self, item: nodes.Item) -> Generator[None]:

empty: dict[str, list[logging.LogRecord]] = {}
item.stash[caplog_records_key] = empty
yield from self._runtest_for(item, "setup")
with self._runtest_for(item, "setup"):
yield

@hookimpl(wrapper=True)
def pytest_runtest_call(self, item: nodes.Item) -> Generator[None]:
self.log_cli_handler.set_when("call")

yield from self._runtest_for(item, "call")
with self._runtest_for(item, "call"):
yield

@hookimpl(wrapper=True)
def pytest_runtest_teardown(self, item: nodes.Item) -> Generator[None]:
self.log_cli_handler.set_when("teardown")

try:
yield from self._runtest_for(item, "teardown")
with self._runtest_for(item, "teardown"):
yield
finally:
del item.stash[caplog_records_key]
del item.stash[caplog_handler_key]
Expand Down
18 changes: 12 additions & 6 deletions src/_pytest/threadexception.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from contextlib import contextmanager
import threading
import traceback
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Generator
from typing import Iterator
from typing import TYPE_CHECKING
import warnings

Expand Down Expand Up @@ -62,6 +64,7 @@ def __exit__(
del self.args


@contextmanager
def thread_exception_runtest_hook() -> Generator[None]:
with catch_threading_exception() as cm:
try:
Expand All @@ -83,15 +86,18 @@ def thread_exception_runtest_hook() -> Generator[None]:


@pytest.hookimpl(wrapper=True, trylast=True)
def pytest_runtest_setup() -> Generator[None]:
yield from thread_exception_runtest_hook()
def pytest_runtest_setup() -> Iterator[None]:
with thread_exception_runtest_hook():
return (yield)


@pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_call() -> Generator[None]:
yield from thread_exception_runtest_hook()
def pytest_runtest_call() -> Iterator[None]:
with thread_exception_runtest_hook():
return (yield)


@pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_teardown() -> Generator[None]:
yield from thread_exception_runtest_hook()
def pytest_runtest_teardown() -> Iterator[None]:
with thread_exception_runtest_hook():
return (yield)
63 changes: 31 additions & 32 deletions src/_pytest/unraisableexception.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Generator
from typing import Iterator
from typing import TYPE_CHECKING
import warnings

Expand Down Expand Up @@ -38,15 +38,30 @@
# (to break a reference cycle)
"""

def __init__(self) -> None:
self.unraisable: sys.UnraisableHookArgs | None = None
self._old_hook: Callable[[sys.UnraisableHookArgs], Any] | None = None
unraisable: sys.UnraisableHookArgs | None = None
_old_hook: Callable[[sys.UnraisableHookArgs], Any] | None = None

def _hook(self, unraisable: sys.UnraisableHookArgs) -> None:
# Storing unraisable.object can resurrect an object which is being
# finalized. Storing unraisable.exc_value creates a reference cycle.
self.unraisable = unraisable

def _warn_if_triggered(self) -> None:
if self.unraisable:
if self.unraisable.err_msg is not None:
err_msg = self.unraisable.err_msg

Check warning on line 52 in src/_pytest/unraisableexception.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/unraisableexception.py#L52

Added line #L52 was not covered by tests
else:
err_msg = "Exception ignored in"
msg = f"{err_msg}: {self.unraisable.object!r}\n\n"
msg += "".join(
traceback.format_exception(
self.unraisable.exc_type,
self.unraisable.exc_value,
self.unraisable.exc_traceback,
)
)
warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))

def __enter__(self) -> Self:
self._old_hook = sys.unraisablehook
sys.unraisablehook = self._hook
Expand All @@ -61,40 +76,24 @@
assert self._old_hook is not None
sys.unraisablehook = self._old_hook
self._old_hook = None
del self.unraisable


def unraisable_exception_runtest_hook() -> Generator[None]:
with catch_unraisable_exception() as cm:
try:
yield
finally:
if cm.unraisable:
if cm.unraisable.err_msg is not None:
err_msg = cm.unraisable.err_msg
else:
err_msg = "Exception ignored in"
msg = f"{err_msg}: {cm.unraisable.object!r}\n\n"
msg += "".join(
traceback.format_exception(
cm.unraisable.exc_type,
cm.unraisable.exc_value,
cm.unraisable.exc_traceback,
)
)
warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))
self._warn_if_triggered()
if "unraisable" in vars(self):
del self.unraisable


@pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_setup() -> Generator[None]:
yield from unraisable_exception_runtest_hook()
def pytest_runtest_setup() -> Iterator[None]:
with catch_unraisable_exception():
yield


@pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_call() -> Generator[None]:
yield from unraisable_exception_runtest_hook()
def pytest_runtest_call() -> Iterator[None]:
with catch_unraisable_exception():
yield


@pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_teardown() -> Generator[None]:
yield from unraisable_exception_runtest_hook()
def pytest_runtest_teardown() -> Iterator[None]:
with catch_unraisable_exception():
yield
10 changes: 10 additions & 0 deletions testing/example_scripts/hook_exceptions/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from typing import Iterator

import pytest


@pytest.hookimpl(wrapper=True)
def pytest_runtest_call() -> Iterator[None]:
yield
Empty file.
87 changes: 87 additions & 0 deletions testing/example_scripts/hook_exceptions/test_stop_iteration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
test example file exposing mltiple issues with corutine exception passover in case of stopiteration

the stdlib contextmanager implementation explicitly catches
and reshapes in case a StopIteration was send in and is raised out
"""

from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager

import pluggy


def test_stop() -> None:
raise StopIteration()


hookspec = pluggy.HookspecMarker("myproject")
hookimpl = pluggy.HookimplMarker("myproject")


class MySpec:
"""A hook specification namespace."""

@hookspec
def myhook(self, arg1: int, arg2: int) -> int: # type: ignore[empty-body]
"""My special little hook that you can customize."""


class Plugin_1:
"""A hook implementation namespace."""

@hookimpl
def myhook(self, arg1: int, arg2: int) -> int:
print("inside Plugin_1.myhook()")
raise StopIteration()


class Plugin_2:
"""A 2nd hook implementation namespace."""

@hookimpl(wrapper=True)
def myhook(self) -> Iterator[None]:
return (yield)


def try_pluggy() -> None:
# create a manager and add the spec
pm = pluggy.PluginManager("myproject")
pm.add_hookspecs(MySpec)

# register plugins
pm.register(Plugin_1())
pm.register(Plugin_2())

# call our ``myhook`` hook
results = pm.hook.myhook(arg1=1, arg2=2)
print(results)


@contextmanager
def my_cm() -> Iterator[None]:
try:
yield
except Exception as e:
print(e)
raise StopIteration()


def inner() -> None:
with my_cm():
raise StopIteration()


def try_context() -> None:
inner()


mains = {"pluggy": try_pluggy, "context": try_context}

if __name__ == "__main__":
import sys

if len(sys.argv) == 2:
mains[sys.argv[1]]()
Loading