Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for starting sockets #4

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions src/zmq_anyio/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import select
import selectors
import threading
import warnings
from collections import deque
from contextlib import AsyncExitStack
from functools import partial
from itertools import chain
from socket import socketpair
from threading import Event
from typing import (
Any,
Awaitable,
Expand All @@ -18,8 +18,8 @@
cast,
)

from anyio import create_task_group, from_thread, sleep, to_thread, wait_socket_readable
from anyio.abc import TaskGroup
from anyio import Event, TASK_STATUS_IGNORED, create_task_group, from_thread, sleep, to_thread, wait_socket_readable
from anyio.abc import TaskGroup, TaskStatus
from anyioutils import Future, Task, create_task

import zmq
Expand Down Expand Up @@ -166,9 +166,10 @@ class Socket(zmq.Socket):
_fd = None
_exit_stack = None
_task_group = None
_stop_event = None
_select_socket_r = None
_select_socket_w = None
_stopped = None
_started = None

def __init__(
self,
Expand All @@ -193,15 +194,16 @@ def __init__(
self._send_futures = deque()
self._state = 0
self._fd = self._shadow_sock.FD
self._stop_event = Event()
self._select_socket_r, self._select_socket_w = socketpair()
self._select_socket_r.setblocking(False)
self._select_socket_w.setblocking(False)
self._started = Event()
self._stopped = threading.Event()

def close(self, linger: int | None = None) -> None:
assert self._stop_event is not None
assert self._stopped is not None
assert self._select_socket_w is not None
self._stop_event.set()
self._stopped.set()
self._select_socket_w.send(b"a")
if not self.closed and self._fd is not None:
event_list: list[_FutureEvent] = list(
Expand Down Expand Up @@ -678,26 +680,33 @@ async def __aenter__(self) -> Socket:
async with AsyncExitStack() as exit_stack:
self._task_group = await exit_stack.enter_async_context(create_task_group())
self._exit_stack = exit_stack.pop_all()
self._task_group.start_soon(self.start)
await self._task_group.start(self.start)

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
try:
self.close()
except BaseException:
pass
self._task_group.cancel_scope.cancel()
self.close()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def start(self):
await to_thread.run_sync(self._reader, abandon_on_cancel=True)
#create_task(self._handle_events(task_group), task_group)
async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None:
assert self._task_group is not None
assert self._started is not None
self._task_group.start_soon(partial(to_thread.run_sync, self._reader, abandon_on_cancel=True))
await self._started.wait()
task_status.started()

def _reader(self):
from_thread.run_sync(self._started.set)
while True:
try:
rs, ws, xs = select.select([self._shadow_sock, self._select_socket_r.fileno()], [], [self._shadow_sock, self._select_socket_r.fileno()])
except OSError as e:
return
if self._stop_event.is_set():
if self._stopped.is_set():
return
self._read()

Expand Down
27 changes: 26 additions & 1 deletion tests/test_socket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import zmq
from anyio import create_task_group, sleep
from anyio import create_task_group, move_on_after, sleep, to_thread
from zmq_anyio import Socket

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -81,3 +81,28 @@ async def recv():
tg.start_soon(recv)
await sleep(0.1)
a.send(b"hi")


@pytest.mark.parametrize("total_threads", [1, 2])
async def test_start_socket(total_threads, create_bound_pair):
to_thread.current_default_thread_limiter().total_tokens = total_threads

a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP))
a_started = False
b_started = False

with pytest.raises(BaseException):
async with b:
b_started = True
with move_on_after(0.1):
async with a:
a_started = True
raise RuntimeError

assert b_started
if total_threads == 1:
assert not a_started
else:
assert a_started

to_thread.current_default_thread_limiter().total_tokens = 40