Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid refcycles in run exc #3120

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions demo.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you accidentally left this instead of copying parts over as a test.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, sorry I pushed this in draft so I don't lose it

Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import trio


async def main():
err = None
with trio.CancelScope() as scope:
scope.cancel()
try:
await trio.sleep_forever()
except BaseException as e:
err = e
raise
breakpoint()


# trio.run(main)

import gc

import objgraph
from anyio import CancelScope, get_cancelled_exc_class


async def test_exception_refcycles_propagate_cancellation_error() -> None:
"""Test that TaskGroup deletes cancelled_exc"""
exc = None

with CancelScope() as cs:
cs.cancel()
try:
await trio.sleep_forever()
except get_cancelled_exc_class() as e:
exc = e
raise

assert isinstance(exc, get_cancelled_exc_class())
gc.collect()
objgraph.show_chain(
objgraph.find_backref_chain(
gc.get_referrers(exc)[0],
objgraph.is_proper_module,
),
)


# trio.run(test_exception_refcycles_propagate_cancellation_error)


class MyException(Exception):
pass


async def main():
raise MyException


def inner():
try:
trio.run(main)
except MyException:
pass


import refcycle

gc.disable()
gc.collect()
inner()
garbage = refcycle.garbage()
for i, component in enumerate(garbage.source_components()):
component.export_image(f"{i}_example.svg")
garbage.export_image("example.svg")
19 changes: 13 additions & 6 deletions src/trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,7 @@ def close(self) -> None:
self.asyncgens.close()
if "after_run" in self.instruments:
self.instruments.call("after_run")
self.system_nursery: Nursery | None = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type hint here isn't necessary

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as this isn't the first time system_nursery is set yes

# This is where KI protection gets disabled, so we do it last
self.ki_manager.close()

Expand Down Expand Up @@ -1920,6 +1921,7 @@ def task_exited(self, task: Task, outcome: Outcome[Any]) -> None:
task._activate_cancel_status(None)
self.tasks.remove(task)
if task is self.init_task:
self.init_task = None
# If the init task crashed, then something is very wrong and we
# let the error propagate. (It'll eventually be wrapped in a
# TrioInternalError.)
Expand All @@ -1930,6 +1932,7 @@ def task_exited(self, task: Task, outcome: Outcome[Any]) -> None:
raise TrioInternalError
else:
if task is self.main_task:
self.main_task = None
self.main_task_outcome = outcome
outcome = Value(None)
assert task._parent_nursery is not None, task
Expand Down Expand Up @@ -2394,12 +2397,15 @@ def run(
sniffio_library.name = prev_library
# Inlined copy of runner.main_task_outcome.unwrap() to avoid
# cluttering every single Trio traceback with an extra frame.
if isinstance(runner.main_task_outcome, Value):
return cast(RetT, runner.main_task_outcome.value)
elif isinstance(runner.main_task_outcome, Error):
raise runner.main_task_outcome.error
else: # pragma: no cover
raise AssertionError(runner.main_task_outcome)
try:
if isinstance(runner.main_task_outcome, Value):
return cast(RetT, runner.main_task_outcome.value)
elif isinstance(runner.main_task_outcome, Error):
raise runner.main_task_outcome.error
else: # pragma: no cover
raise AssertionError(runner.main_task_outcome)
finally:
del runner


def start_guest_run(
Expand Down Expand Up @@ -2808,6 +2814,7 @@ def unrolled_run(
if isinstance(runner.main_task_outcome, Error):
ki.__context__ = runner.main_task_outcome.error
runner.main_task_outcome = Error(ki)
del runner


################################################################
Expand Down
Loading