diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index d8171017..b4736ca5 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -16,13 +16,15 @@ from binascii import b2a_hex from collections import defaultdict, deque from io import StringIO, TextIOBase -from threading import Event, Thread, local +from threading import local from typing import Any, Callable import zmq -from anyio import create_task_group, run, sleep, to_thread +from anyio import sleep from jupyter_client.session import extract_header +from .thread import BaseThread + # ----------------------------------------------------------------------------- # Globals # ----------------------------------------------------------------------------- @@ -37,38 +39,6 @@ # ----------------------------------------------------------------------------- -class _IOPubThread(Thread): - """A thread for a IOPub.""" - - def __init__(self, tasks, **kwargs): - """Initialize the thread.""" - super().__init__(name="IOPub", **kwargs) - self._tasks = tasks - self.pydev_do_not_trace = True - self.is_pydev_daemon_thread = True - self.daemon = True - self.__stop = Event() - - def run(self): - """Run the thread.""" - self.name = "IOPub" - run(self._main) - - async def _main(self): - async with create_task_group() as tg: - for task in self._tasks: - tg.start_soon(task) - await to_thread.run_sync(self.__stop.wait) - tg.cancel_scope.cancel() - - def stop(self): - """Stop the thread. - - This method is threadsafe. - """ - self.__stop.set() - - class IOPubThread: """An object for sending IOPub messages in a background thread @@ -111,7 +81,9 @@ def __init__(self, socket, pipe=False): tasks = [self._handle_event, self._run_event_pipe_gc] if pipe: tasks.append(self._handle_pipe_msgs) - self.thread = _IOPubThread(tasks) + self.thread = BaseThread(name="IOPub", daemon=True) + for task in tasks: + self.thread.start_soon(task) def _setup_event_pipe(self): """Create the PULL socket listening for events that should fire in this thread.""" @@ -181,7 +153,7 @@ async def _handle_event(self): event_f = self._events.popleft() event_f() except Exception: - if self.thread.__stop.is_set(): + if self.thread.stopped.is_set(): return raise @@ -215,7 +187,7 @@ async def _handle_pipe_msgs(self): while True: await self._handle_pipe_msg() except Exception: - if self.thread.__stop.is_set(): + if self.thread.stopped.is_set(): return raise diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index d496e0c9..095fc6d6 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -16,6 +16,7 @@ import uuid import warnings from datetime import datetime +from functools import partial from signal import SIGINT, SIGTERM, Signals from .thread import CONTROL_THREAD_NAME @@ -529,7 +530,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: self.control_stop = threading.Event() if not self._is_test and self.control_socket is not None: if self.control_thread: - self.control_thread.add_task(self.control_main) + self.control_thread.start_soon(self.control_main) self.control_thread.start() else: tg.start_soon(self.control_main) @@ -544,9 +545,11 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: # Assign tasks to and start shell channel thread. manager = self.shell_channel_thread.manager - self.shell_channel_thread.add_task(self.shell_channel_thread_main) - self.shell_channel_thread.add_task(manager.listen_from_control, self.shell_main) - self.shell_channel_thread.add_task(manager.listen_from_subshells) + self.shell_channel_thread.start_soon(self.shell_channel_thread_main) + self.shell_channel_thread.start_soon( + partial(manager.listen_from_control, self.shell_main) + ) + self.shell_channel_thread.start_soon(manager.listen_from_subshells) self.shell_channel_thread.start() else: if not self._is_test and self.shell_socket is not None: diff --git a/ipykernel/subshell_manager.py b/ipykernel/subshell_manager.py index 805d6f81..66caaafb 100644 --- a/ipykernel/subshell_manager.py +++ b/ipykernel/subshell_manager.py @@ -7,6 +7,7 @@ import typing as t import uuid from dataclasses import dataclass +from functools import partial from threading import Lock, current_thread, main_thread import zmq @@ -186,8 +187,8 @@ async def _create_subshell(self, subshell_task: t.Any) -> str: await self._send_stream.send(subshell_id) address = self._get_inproc_socket_address(subshell_id) - thread.add_task(thread.create_pair_socket, self._context, address) - thread.add_task(subshell_task, subshell_id) + thread.start_soon(partial(thread.create_pair_socket, self._context, address)) + thread.start_soon(partial(subshell_task, subshell_id)) thread.start() return subshell_id diff --git a/ipykernel/thread.py b/ipykernel/thread.py index 40509ece..4c9edf86 100644 --- a/ipykernel/thread.py +++ b/ipykernel/thread.py @@ -1,6 +1,10 @@ """Base class for threads.""" -import typing as t +from __future__ import annotations + +from collections.abc import Awaitable +from queue import Queue from threading import Event, Thread +from typing import Callable from anyio import create_task_group, run, to_thread @@ -14,24 +18,27 @@ class BaseThread(Thread): def __init__(self, **kwargs): """Initialize the thread.""" super().__init__(**kwargs) + self.started = Event() + self.stopped = Event() self.pydev_do_not_trace = True self.is_pydev_daemon_thread = True - self.__stop = Event() - self._tasks_and_args: list[tuple[t.Any, t.Any]] = [] + self._tasks: Queue[Callable[[], Awaitable[None]] | None] = Queue() - def add_task(self, task: t.Any, *args: t.Any) -> None: - # May only add tasks before the thread is started. - self._tasks_and_args.append((task, args)) + def start_soon(self, task: Callable[[], Awaitable[None]] | None) -> None: + self._tasks.put(task) - def run(self) -> t.Any: + def run(self) -> None: """Run the thread.""" - return run(self._main) + run(self._main) async def _main(self) -> None: async with create_task_group() as tg: - for task, args in self._tasks_and_args: - tg.start_soon(task, *args) - await to_thread.run_sync(self.__stop.wait) + self.started.set() + while True: + task = await to_thread.run_sync(self._tasks.get) + if task is None: + break + tg.start_soon(task) tg.cancel_scope.cancel() def stop(self) -> None: @@ -39,4 +46,5 @@ def stop(self) -> None: This method is threadsafe. """ - self.__stop.set() + self._tasks.put(None) + self.stopped.set()