Skip to content

Commit

Permalink
Merge pull request #3112 from graingert/dont-use-f_locals-for-agen-fi…
Browse files Browse the repository at this point in the history
…nalization
  • Loading branch information
graingert authored Nov 21, 2024
2 parents 7334c42 + ca6e01c commit 34e7a2e
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 21 deletions.
5 changes: 5 additions & 0 deletions newsfragments/3112.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Rework foreign async generator finalization to track async generator
ids rather than mutating ``ag_frame.f_locals``. This fixes an issue
with the previous implementation: locals' lifetimes will no longer be
extended by materialization in the ``ag_frame.f_locals`` dictionary that
the previous finalization dispatcher logic needed to access to do its work.
46 changes: 36 additions & 10 deletions src/trio/_core/_asyncgens.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import warnings
import weakref
from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING, NoReturn, TypeVar

import attrs

Expand All @@ -16,14 +16,31 @@
ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors")

if TYPE_CHECKING:
from collections.abc import Callable
from types import AsyncGeneratorType

from typing_extensions import ParamSpec

_P = ParamSpec("_P")

_WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]]
_ASYNC_GEN_SET = set[AsyncGeneratorType[object, NoReturn]]
else:
_WEAK_ASYNC_GEN_SET = weakref.WeakSet
_ASYNC_GEN_SET = set

_R = TypeVar("_R")


@_core.disable_ki_protection
def _call_without_ki_protection(
f: Callable[_P, _R],
/,
*args: _P.args,
**kwargs: _P.kwargs,
) -> _R:
return f(*args, **kwargs)


@attrs.define(eq=False)
class AsyncGenerators:
Expand All @@ -35,6 +52,11 @@ class AsyncGenerators:
# regular set so we don't have to deal with GC firing at
# unexpected times.
alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attrs.Factory(_WEAK_ASYNC_GEN_SET)
# The ids of foreign async generators are added to this set when first
# iterated. Usually it is not safe to refer to ids like this, but because
# we're using a finalizer we can ensure ids in this set do not outlive
# their async generator.
foreign: set[int] = attrs.Factory(set)

# This collects async generators that get garbage collected during
# the one-tick window between the system nursery closing and the
Expand All @@ -51,10 +73,10 @@ def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None:
# An async generator first iterated outside of a Trio
# task doesn't belong to Trio. Probably we're in guest
# mode and the async generator belongs to our host.
# The locals dictionary is the only good place to
# A strong set of ids is one of the only good places to
# remember this fact, at least until
# https://bugs.python.org/issue40916 is implemented.
agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True
# https://github.com/python/cpython/issues/85093 is implemented.
self.foreign.add(id(agen))
if self.prev_hooks.firstiter is not None:
self.prev_hooks.firstiter(agen)

Expand All @@ -76,13 +98,16 @@ def finalize_in_trio_context(
# have hit it.
self.trailing_needs_finalize.add(agen)

@_core.enable_ki_protection
def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None:
agen_name = name_asyncgen(agen)
try:
is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen")
except AttributeError: # pragma: no cover
self.foreign.remove(id(agen))
except KeyError:
is_ours = True
else:
is_ours = False

agen_name = name_asyncgen(agen)
if is_ours:
runner.entry_queue.run_sync_soon(
finalize_in_trio_context,
Expand All @@ -105,8 +130,9 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None:
)
else:
# Not ours -> forward to the host loop's async generator finalizer
if self.prev_hooks.finalizer is not None:
self.prev_hooks.finalizer(agen)
finalizer = self.prev_hooks.finalizer
if finalizer is not None:
_call_without_ki_protection(finalizer, agen)
else:
# Host has no finalizer. Reimplement the default
# Python behavior with no hooks installed: throw in
Expand All @@ -116,7 +142,7 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None:
try:
# If the next thing is a yield, this will raise RuntimeError
# which we allow to propagate
closer.send(None)
_call_without_ki_protection(closer.send, None)
except StopIteration:
pass
else:
Expand Down
64 changes: 53 additions & 11 deletions src/trio/_core/_tests/test_guest_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import contextlib
import contextvars
import queue
import signal
import socket
Expand All @@ -11,6 +10,7 @@
import time
import traceback
import warnings
import weakref
from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence
from functools import partial
from math import inf
Expand All @@ -22,6 +22,7 @@
)

import pytest
import sniffio
from outcome import Outcome

import trio
Expand Down Expand Up @@ -234,7 +235,8 @@ async def trio_main(in_host: InHost) -> str:


def test_guest_mode_sniffio_integration() -> None:
from sniffio import current_async_library, thread_local as sniffio_library
current_async_library = sniffio.current_async_library
sniffio_library = sniffio.thread_local

async def trio_main(in_host: InHost) -> str:
async def synchronize() -> None:
Expand Down Expand Up @@ -458,9 +460,9 @@ def aiotrio_run(

async def aio_main() -> T:
nonlocal run_sync_soon_not_threadsafe
trio_done_fut = loop.create_future()
trio_done_fut: asyncio.Future[Outcome[T]] = loop.create_future()

def trio_done_callback(main_outcome: Outcome[object]) -> None:
def trio_done_callback(main_outcome: Outcome[T]) -> None:
print(f"trio_fn finished: {main_outcome!r}")
trio_done_fut.set_result(main_outcome)

Expand All @@ -479,9 +481,11 @@ def trio_done_callback(main_outcome: Outcome[object]) -> None:
strict_exception_groups=strict_exception_groups,
)

return (await trio_done_fut).unwrap() # type: ignore[no-any-return]
return (await trio_done_fut).unwrap()

try:
# can't use asyncio.run because that fails on Windows (3.8, x64, with
# Komodia LSP) and segfaults on Windows (3.9, x64, with Komodia LSP)
return loop.run_until_complete(aio_main())
finally:
loop.close()
Expand Down Expand Up @@ -655,8 +659,6 @@ async def trio_main(in_host: InHost) -> None:

@restore_unraisablehook()
def test_guest_mode_asyncgens() -> None:
import sniffio

record = set()

async def agen(label: str) -> AsyncGenerator[int, None]:
Expand All @@ -683,9 +685,49 @@ async def trio_main() -> None:

gc_collect_harder()

# Ensure we don't pollute the thread-level context if run under
# an asyncio without contextvars support (3.6)
context = contextvars.copy_context()
context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True)
aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True)

assert record == {("asyncio", "asyncio"), ("trio", "trio")}


@restore_unraisablehook()
def test_guest_mode_asyncgens_garbage_collection() -> None:
record: set[tuple[str, str, bool]] = set()

async def agen(label: str) -> AsyncGenerator[int, None]:
class A:
pass

a = A()
a_wr = weakref.ref(a)
assert sniffio.current_async_library() == label
try:
yield 1
finally:
library = sniffio.current_async_library()
with contextlib.suppress(trio.Cancelled):
await sys.modules[library].sleep(0)

del a
if sys.implementation.name == "pypy":
gc_collect_harder()

record.add((label, library, a_wr() is None))

async def iterate_in_aio() -> None:
await agen("asyncio").asend(None)

async def trio_main() -> None:
task = asyncio.ensure_future(iterate_in_aio())
done_evt = trio.Event()
task.add_done_callback(lambda _: done_evt.set())
with trio.fail_after(1):
await done_evt.wait()

await agen("trio").asend(None)

gc_collect_harder()

aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True)

assert record == {("asyncio", "asyncio", True), ("trio", "trio", True)}

0 comments on commit 34e7a2e

Please sign in to comment.