diff --git a/README.md b/README.md index 1b9aede..554bc04 100644 --- a/README.md +++ b/README.md @@ -3,3 +3,34 @@ # zmq-anyio Asynchronous API for ZMQ using AnyIO. + +## Usage + +`zmq_anyio.Socket` is a subclass of `zmq.Socket`. Here is how it must be used: +- Create a blocking ZMQ socket `sock` using a `zmq.Context`. +- Create an async `zmq_anyio.Socket(sock)`, passing the `sock`. +- Use the `zmq_anyio.Socket` with an async context manager. +- Use `arecv()` for the async API, `recv()` for the blocking API, etc. + +```py +import anyio +import zmq +import zmq_anyio + +ctx = zmq.Context() +sock1 = ctx.socket(zmq.PAIR) +port = sock1.bind("tcp://127.0.0.1:1234") +sock2 = ctx.socket(zmq.PAIR) +sock2.connect("tcp://127.0.0.1:1234") + +# wrap the `zmq.Socket` with `zmq_anyio.Socket`: +sock1 = zmq_anyio.Socket(sock1) +sock2 = zmq_anyio.Socket(sock2) + +async def main(): + async with sock1, sock2: # use an async context manager + await sock1.asend(b"Hello") # use `asend` instead of `send` + assert await sock2.arecv() == b"Hello" # use `arecv` instead of `recv` + +anyio.run(main) +``` diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 1bd479f..c561ae3 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -212,8 +212,8 @@ def close(self, linger: int | None = None) -> None: def get(self, key): result = super().get(key) - # if key == EVENTS: - # self._schedule_remaining_events(result) + if key == EVENTS: + self._schedule_remaining_events(result) return result get.__doc__ = zmq.Socket.get.__doc__ @@ -429,45 +429,6 @@ async def asend_multipart( def _deserialize(self, recvd, load): """Deserialize with Futures""" return load(recvd) - # f = Future() - - # def _chain(_): - # """Chain result through serialization to recvd""" - # if f.done(): - # # chained future may be cancelled, which means nobody is going to get this result - # # if it's an error, that's no big deal (probably zmq.Again), - # # but if it's a successful recv, this is a dropped message! - # if not recvd.cancelled() and recvd.exception() is None: - # warnings.warn( - # # is there a useful stacklevel? - # # ideally, it would point to where `f.cancel()` was called - # f"Future {f} completed while awaiting {recvd}. A message has been dropped!", - # RuntimeWarning, - # ) - # return - # if recvd.exception(): - # f.set_exception(recvd.exception()) - # else: - # buf = recvd.result() - # try: - # loaded = load(buf) - # except Exception as e: - # f.set_exception(e) - # else: - # f.set_result(loaded) - - # recvd.add_done_callback(_chain) - - # def _chain_cancel(_): - # """Chain cancellation from f to recvd""" - # if recvd.done(): - # return - # if f.cancelled(): - # recvd.cancel() - - # f.add_done_callback(_chain_cancel) - - # return await f.wait() async def apoll(self, timeout=None, flags=zmq.POLLIN) -> int: # type: ignore """poll the socket for events @@ -839,6 +800,9 @@ async def _start(self, *, task_status: TaskStatus[None]): raise RuntimeError("Socket already started") self.started.set() task_status.started() - while True: - await wait_socket_readable(self._shadow_sock.FD) # type: ignore[arg-type] - await self._handle_events() + try: + while True: + await wait_socket_readable(self._shadow_sock.FD) # type: ignore[arg-type] + await self._handle_events() + except Exception: + pass diff --git a/tests/test_socket.py b/tests/test_socket.py index 00997dd..93e2983 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -40,7 +40,7 @@ async def recv(messages): async def test_arecv_json(context, create_bound_pair): a, b = create_bound_pair(zmq.PUSH, zmq.PULL) a, b = Socket(a), Socket(b) - async with b, a, create_task_group() as tg: + async with a, b, create_task_group() as tg: async def recv(messages): for message in messages: