Skip to content

Commit

Permalink
Review
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 12, 2024
1 parent 76d23fa commit 8b18582
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 58 deletions.
34 changes: 22 additions & 12 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@
from ..abc._eventloop import StrOrBytesPath
from ..lowlevel import RunVar
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from ._selector_thread import _get_selector_windows

if sys.version_info >= (3, 10):
from typing import ParamSpec
Expand Down Expand Up @@ -2684,19 +2683,19 @@ async def wait_socket_readable(cls, sock: socket.socket) -> None:
raise BusyResourceError("reading from") from None

loop = get_running_loop()
if (
sys.platform == "win32"
and asyncio.get_event_loop_policy().__class__.__name__
== "WindowsProactorEventLoopPolicy"
):
add_reader = loop.add_reader
event = read_events[sock] = asyncio.Event()
try:
add_reader(sock, event.set)
except NotImplementedError:
# Proactor on Windows does not yet implement add/remove reader
from ._selector_thread import _get_selector_windows

selector = _get_selector_windows(loop)
add_reader = selector.add_reader
selector.add_reader(sock, event.set)
remove_reader = selector.remove_reader
else:
add_reader = loop.add_reader
remove_reader = loop.remove_reader
event = read_events[sock] = asyncio.Event()
add_reader(sock, event.set)
try:
await event.wait()
finally:
Expand All @@ -2722,13 +2721,24 @@ async def wait_socket_writable(cls, sock: socket.socket) -> None:
raise BusyResourceError("writing to") from None

loop = get_running_loop()
add_writer = loop.add_writer
event = write_events[sock] = asyncio.Event()
loop.add_writer(sock.fileno(), event.set)
try:
add_writer(sock.fileno(), event.set)
except NotImplementedError:
# Proactor on Windows does not yet implement add/remove writer
from ._selector_thread import _get_selector_windows

selector = _get_selector_windows(loop)
selector.add_writer(sock, event.set)
remove_writer = selector.remove_writer
else:
remove_writer = loop.remove_writer
try:
await event.wait()
finally:
if write_events.pop(sock, None) is not None:
loop.remove_writer(sock)
remove_writer(sock)
writable = True
else:
writable = False
Expand Down
42 changes: 8 additions & 34 deletions src/anyio/_backends/_selector_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from __future__ import annotations

import asyncio
import atexit
import errno
import functools
import select
Expand All @@ -21,6 +20,8 @@
)
from weakref import WeakKeyDictionary

from ._asyncio import find_root_task

if typing.TYPE_CHECKING:
from typing_extensions import Protocol

Expand All @@ -38,7 +39,7 @@ def fileno(self) -> int:
_selector_loops: set[SelectorThread] = set()


def _atexit_callback() -> None:
def _at_loop_close_callback(future: asyncio.Future) -> None:
for loop in _selector_loops:
with loop._select_cond:
loop._closing_selector = True
Expand All @@ -56,12 +57,7 @@ def _atexit_callback() -> None:
_selector_loops.clear()


atexit.register(_atexit_callback)


# SelectorThread from tornado 6.4.0


class SelectorThread:
"""Define ``add_reader`` methods to be called in a background select thread.
Expand All @@ -84,19 +80,6 @@ def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None:
) = None
self._closing_selector = False
self._thread: threading.Thread | None = None
self._thread_manager_handle = self._thread_manager()

async def thread_manager_anext() -> None:
# the anext builtin wasn't added until 3.10. We just need to iterate
# this generator one step.
await self._thread_manager_handle.__anext__()

# When the loop starts, start the thread. Not too soon because we can't
# clean up if we get to this point but the event loop is closed without
# starting.
self._real_loop.call_soon(
lambda: self._real_loop.create_task(thread_manager_anext())
)

self._readers: dict[_FileDescriptorLike, Callable] = {}
self._writers: dict[_FileDescriptorLike, Callable] = {}
Expand All @@ -108,6 +91,7 @@ async def thread_manager_anext() -> None:
self._waker_w.setblocking(False)
_selector_loops.add(self)
self.add_reader(self._waker_r, self._consume_waker)
self._thread_manager()

def close(self) -> None:
if self._closed:
Expand All @@ -124,30 +108,19 @@ def close(self) -> None:
self._waker_w.close()
self._closed = True

async def _thread_manager(self) -> typing.AsyncGenerator[None, None]:
def _thread_manager(self) -> None:
# Create a thread to run the select system call. We manage this thread
# manually so we can trigger a clean shutdown from an atexit hook. Note
# manually so we can trigger a clean shutdown at loop teardown. Note
# that due to the order of operations at shutdown, only daemon threads
# can be shut down in this way (non-daemon threads would require the
# introduction of a new hook: https://bugs.python.org/issue41962)
self._thread = threading.Thread(
name="Tornado selector",
name="AnyIO selector",
daemon=True,
target=self._run_select,
)
self._thread.start()
self._start_select()
try:
# The presense of this yield statement means that this coroutine
# is actually an asynchronous generator, which has a special
# shutdown protocol. We wait at this yield point until the
# event loop's shutdown_asyncgens method is called, at which point
# we will get a GeneratorExit exception and can shut down the
# selector thread.
yield
except GeneratorExit:
self.close()
raise

def _wake_selector(self) -> None:
if self._closed:
Expand Down Expand Up @@ -298,6 +271,7 @@ def _get_selector_windows(
if asyncio_loop in _selectors:
return _selectors[asyncio_loop]

find_root_task().add_done_callback(_at_loop_close_callback)
selector_thread = _selectors[asyncio_loop] = SelectorThread(asyncio_loop)

# patch loop.close to also close the selector thread
Expand Down
21 changes: 9 additions & 12 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,17 +1858,14 @@ def client(port: int) -> None:
sock.connect(("127.0.0.1", port))
sock.sendall(b"Hello, world")

with move_on_after(0.1):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
port = sock.getsockname()[1]
sock.listen()
thread = Thread(target=client, args=(port,), daemon=True)
thread.start()
conn, addr = sock.accept()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
port = sock.getsockname()[1]
sock.listen()
thread = Thread(target=client, args=(port,))
thread.start()
thread.join()
conn, addr = sock.accept()
with fail_after(5):
with conn:
await wait_socket_readable(conn)
socket_readable = True

assert socket_readable
thread.join()

0 comments on commit 8b18582

Please sign in to comment.