diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 9876083..7e7a43c 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -2,13 +2,13 @@ import select import selectors +import threading import warnings from collections import deque from contextlib import AsyncExitStack from functools import partial from itertools import chain from socket import socketpair -from threading import Event from typing import ( Any, Awaitable, @@ -18,8 +18,8 @@ cast, ) -from anyio import create_task_group, from_thread, sleep, to_thread, wait_socket_readable -from anyio.abc import TaskGroup +from anyio import Event, TASK_STATUS_IGNORED, create_task_group, from_thread, sleep, to_thread, wait_socket_readable +from anyio.abc import TaskGroup, TaskStatus from anyioutils import Future, Task, create_task import zmq @@ -166,9 +166,10 @@ class Socket(zmq.Socket): _fd = None _exit_stack = None _task_group = None - _stop_event = None _select_socket_r = None _select_socket_w = None + _stopped = None + _started = None def __init__( self, @@ -193,15 +194,16 @@ def __init__( self._send_futures = deque() self._state = 0 self._fd = self._shadow_sock.FD - self._stop_event = Event() self._select_socket_r, self._select_socket_w = socketpair() self._select_socket_r.setblocking(False) self._select_socket_w.setblocking(False) + self._started = Event() + self._stopped = threading.Event() def close(self, linger: int | None = None) -> None: - assert self._stop_event is not None + assert self._stopped is not None assert self._select_socket_w is not None - self._stop_event.set() + self._stopped.set() self._select_socket_w.send(b"a") if not self.closed and self._fd is not None: event_list: list[_FutureEvent] = list( @@ -678,26 +680,33 @@ async def __aenter__(self) -> Socket: async with AsyncExitStack() as exit_stack: self._task_group = await exit_stack.enter_async_context(create_task_group()) self._exit_stack = exit_stack.pop_all() - self._task_group.start_soon(self.start) + await self._task_group.start(self.start) return self async def __aexit__(self, exc_type, exc_value, exc_tb): + try: + self.close() + except BaseException: + pass self._task_group.cancel_scope.cancel() - self.close() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) - async def start(self): - await to_thread.run_sync(self._reader, abandon_on_cancel=True) - #create_task(self._handle_events(task_group), task_group) + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + assert self._task_group is not None + assert self._started is not None + self._task_group.start_soon(partial(to_thread.run_sync, self._reader, abandon_on_cancel=True)) + await self._started.wait() + task_status.started() def _reader(self): + from_thread.run_sync(self._started.set) while True: try: rs, ws, xs = select.select([self._shadow_sock, self._select_socket_r.fileno()], [], [self._shadow_sock, self._select_socket_r.fileno()]) except OSError as e: return - if self._stop_event.is_set(): + if self._stopped.is_set(): return self._read() diff --git a/tests/test_socket.py b/tests/test_socket.py index c7d7857..a35620e 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -1,6 +1,6 @@ import pytest import zmq -from anyio import create_task_group, sleep +from anyio import create_task_group, move_on_after, sleep, to_thread from zmq_anyio import Socket pytestmark = pytest.mark.anyio @@ -81,3 +81,28 @@ async def recv(): tg.start_soon(recv) await sleep(0.1) a.send(b"hi") + + +@pytest.mark.parametrize("total_threads", [1, 2]) +async def test_start_socket(total_threads, create_bound_pair): + to_thread.current_default_thread_limiter().total_tokens = total_threads + + a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP)) + a_started = False + b_started = False + + with pytest.raises(BaseException): + async with b: + b_started = True + with move_on_after(0.1): + async with a: + a_started = True + raise RuntimeError + + assert b_started + if total_threads == 1: + assert not a_started + else: + assert a_started + + to_thread.current_default_thread_limiter().total_tokens = 40