From 336e2b64224435c5a9cc8b924e1941df4370ffe8 Mon Sep 17 00:00:00 2001
From: Norman Rzepka
Date: Thu, 16 Nov 2023 14:27:54 +0100
Subject: [PATCH] More dask features (#959)

* upgrades to mypy 1.6
* pr feedback
* changelog
* adds sigint, mem and cpus support
* changelog
* weakref handle_kill
* test dask in CI
* typing
* ci
* ci
* fix tests
---
 .github/workflows/ci.yml                      |  10 +-
 cluster_tools/Changelog.md                    |   4 +
 cluster_tools/cluster_tools/executors/dask.py | 151 +++++++++++++++++-
 .../schedulers/cluster_executor.py            |  18 ++-
 cluster_tools/tests/test_all.py               |  14 +-
 5 files changed, 183 insertions(+), 14 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 41a9972be..713f104fd 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -37,7 +37,7 @@ jobs:
     strategy:
       max-parallel: 4
       matrix:
-        executors: [multiprocessing, slurm, kubernetes]
+        executors: [multiprocessing, slurm, kubernetes, dask]
         python-version: ["3.11", "3.10", "3.9", "3.8"]
     defaults:
       run:
@@ -88,7 +88,7 @@ jobs:
           ./kind load docker-image scalableminds/cluster-tools:latest

       - name: Install dependencies (without docker)
-        if: ${{ matrix.executors == 'multiprocessing' || matrix.executors == 'kubernetes' }}
+        if: ${{ matrix.executors != 'slurm' }}
         run: |
           pip install -r ../requirements.txt
           poetry install
@@ -130,6 +130,12 @@ jobs:
           cd tests
           PYTEST_EXECUTORS=kubernetes poetry run python -m pytest -sv test_all.py test_kubernetes.py

+      - name: Run dask tests
+        if: ${{ matrix.executors == 'dask' && matrix.python-version != '3.8' }}
+        run: |
+          cd tests
+          PYTEST_EXECUTORS=dask poetry run python -m pytest -sv test_all.py
+
   webknossos_linux:
     needs: changes
     if: |
diff --git a/cluster_tools/Changelog.md b/cluster_tools/Changelog.md
index ec804419e..73b1adae9 100644
--- a/cluster_tools/Changelog.md
+++ b/cluster_tools/Changelog.md
@@ -12,8 +12,12 @@ For upgrade instructions, please check the respective *Breaking Changes* section
 ### Breaking Changes

 ### Added
+- Added SIGINT handling to `DaskExecutor`. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)
+- Added support for resources (e.g. mem, cpus) to `DaskExecutor`. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)
+- The cluster address for the `DaskExecutor` can be configured via the `DASK_ADDRESS` env var. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)

 ### Changed
+- Tasks using the `DaskExecutor` are now run in their own process. This is required so that the GIL is not blocked and the dask worker can keep communicating with the scheduler. Environment variables are propagated to the task processes. [#959](https://github.com/scalableminds/webknossos-libs/pull/959)

 ### Fixed
diff --git a/cluster_tools/cluster_tools/executors/dask.py b/cluster_tools/cluster_tools/executors/dask.py
index 9c4115ce8..192501e24 100644
--- a/cluster_tools/cluster_tools/executors/dask.py
+++ b/cluster_tools/cluster_tools/executors/dask.py
@@ -1,7 +1,11 @@
 import os
+import re
+import signal
+import traceback
 from concurrent import futures
 from concurrent.futures import Future
 from functools import partial
+from multiprocessing import Queue, get_context
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -11,9 +15,11 @@
     Iterator,
     List,
     Optional,
+    Set,
     TypeVar,
     cast,
 )
+from weakref import ReferenceType, ref

 from typing_extensions import ParamSpec

@@ -28,23 +34,119 @@
 _S = TypeVar("_S")


+def _run_in_nanny(
+    queue: Queue, __fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
+) -> None:
+    try:
+        __env = cast(Dict[str, str], kwargs.pop("__env"))
+        for key, value in __env.items():
+            os.environ[key] = value
+
+        ret = __fn(*args, **kwargs)
+        queue.put({"value": ret})
+    except Exception as exc:
+        queue.put({"exception": exc})
+
+
+def _run_with_nanny(
+    __fn: Callable[_P, _T],
+    *args: _P.args,
+    **kwargs: _P.kwargs,
+) -> _T:
+    mp_context = get_context("spawn")
+    q = mp_context.Queue()
+    p = mp_context.Process(target=_run_in_nanny, args=(q, __fn) + args, kwargs=kwargs)
+    p.start()
+    p.join()
+    ret = q.get(timeout=0.1)
+    if "exception" in ret:
+        raise ret["exception"]
+    else:
+        return ret["value"]
+
+
+def _parse_mem(size: str) -> int:
+    units = {"": 1, "K": 2**10, "M": 2**20, "G": 2**30, "T": 2**40}
+    m = re.match(r"^([\d\.]+)\s*([kmgtKMGT]{0,1})$", str(size).strip())
+    assert m is not None, f"Could not parse {size}"
+    number, unit = float(m.group(1)), m.group(2).upper()
+    assert unit in units
+    return int(number * units[unit])
+
+
+def _handle_kill_through_weakref(
+    executor_ref: "ReferenceType[DaskExecutor]",
+    existing_sigint_handler: Any,
+    signum: Optional[int],
+    frame: Any,
+) -> None:
+    executor = executor_ref()
+    if executor is None:
+        return
+    executor.handle_kill(existing_sigint_handler, signum, frame)
+
+
 class DaskExecutor(futures.Executor):
+    """
+    The `DaskExecutor` allows you to run workloads on a dask cluster.
+
+    The executor can be constructed with an existing dask `Client` or
+    from a declarative configuration. The address of the dask scheduler
+    can be part of the configuration or supplied via the `DASK_ADDRESS`
+    environment variable.
+
+    There is support for resource-based scheduling. By default, `mem` and
+    `cpus-per-task` are supported. To make use of them, the dask workers
+    should be started with:
+    `python -m dask worker --no-nanny --nthreads 6 tcp://... --resources "mem=1073741824 cpus=8"`
+    """
+
     client: "Client"
+    pending_futures: Set[Future]
+    job_resources: Optional[Dict[str, Any]]
+    is_shutting_down = False

     def __init__(
-        self,
-        client: "Client",
+        self, client: "Client", job_resources: Optional[Dict[str, Any]] = None
     ) -> None:
         self.client = client
+        self.pending_futures = set()
+        self.job_resources = job_resources
+
+        if self.job_resources is not None:
+            # `mem` needs to be a number for dask, so we need to parse it
+            if "mem" in self.job_resources:
+                self.job_resources["mem"] = _parse_mem(self.job_resources["mem"])
+            if "cpus-per-task" in self.job_resources:
+                self.job_resources["cpus"] = int(
+                    self.job_resources.pop("cpus-per-task")
+                )
+
+        # Clean up if a SIGINT signal is received. However, do not interfere with the
+        # existing signal handler of the process or the
+        # shutdown of the main process which sends SIGTERM signals to terminate all
+        # child processes.
+        existing_sigint_handler = signal.getsignal(signal.SIGINT)
+        signal.signal(
+            signal.SIGINT,
+            partial(_handle_kill_through_weakref, ref(self), existing_sigint_handler),
+        )

     @classmethod
     def from_config(
         cls,
-        job_resources: Dict[str, Any],
+        job_resources: Dict[str, str],
+        **_kwargs: Any,
     ) -> "DaskExecutor":
         from distributed import Client

-        return cls(Client(**job_resources))
+        job_resources = job_resources.copy()
+        address = job_resources.pop("address", None)
+        if address is None:
+            address = os.environ.get("DASK_ADDRESS", None)
+
+        client = Client(address=address)
+        return cls(client, job_resources=job_resources)

     @classmethod
     def as_completed(cls, futures: List["Future[_T]"]) -> Iterator["Future[_T]"]:
@@ -72,7 +174,20 @@ def submit(  # type: ignore[override]
                     __fn,
                 ),
             )
-        fut = self.client.submit(partial(__fn, *args, **kwargs))
+
+        kwargs["__env"] = os.environ.copy()
+
+        # We run the submitted functions in a separate process so that the
+        # GIL is not held for too long; dask workers need to be able to
+        # communicate with the scheduler regularly.
+        __fn = partial(_run_with_nanny, __fn)
+
+        fut = self.client.submit(
+            partial(__fn, *args, **kwargs), pure=False, resources=self.job_resources
+        )
+
+        self.pending_futures.add(fut)
+        fut.add_done_callback(self.pending_futures.remove)

         enrich_future_with_uncaught_warning(fut)
         return fut
@@ -125,8 +240,32 @@ def map(  # type: ignore[override]
     def forward_log(self, fut: "Future[_T]") -> _T:
         return fut.result()

+    def handle_kill(
+        self,
+        existing_sigint_handler: Any,
+        signum: Optional[int],
+        frame: Any,
+    ) -> None:
+        if self.is_shutting_down:
+            return
+
+        self.is_shutting_down = True
+
+        self.client.cancel(list(self.pending_futures))
+
+        if (
+            existing_sigint_handler  # pylint: disable=comparison-with-callable
+            != signal.default_int_handler
+            and callable(existing_sigint_handler)  # Could also be signal.SIG_IGN
+        ):
+            existing_sigint_handler(signum, frame)
+
     def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
+        print(f"{wait=} {cancel_futures=}")
+        traceback.print_stack()
         if wait:
-            self.client.close(timeout=60 * 60 * 24)
+            for fut in list(self.pending_futures):
+                fut.result()
+            self.client.close(timeout=60 * 60)  # 1 hour
         else:
             self.client.close()
diff --git a/cluster_tools/cluster_tools/schedulers/cluster_executor.py b/cluster_tools/cluster_tools/schedulers/cluster_executor.py
index ce3627bb6..a52cb87ad 100644
--- a/cluster_tools/cluster_tools/schedulers/cluster_executor.py
+++ b/cluster_tools/cluster_tools/schedulers/cluster_executor.py
@@ -23,6 +23,7 @@
     Union,
     cast,
 )
+from weakref import ReferenceType, ref

 from typing_extensions import ParamSpec

@@ -45,6 +46,18 @@
 _S = TypeVar("_S")


+def _handle_kill_through_weakref(
+    executor_ref: "ReferenceType[ClusterExecutor]",
+    existing_sigint_handler: Any,
+    signum: Optional[int],
+    frame: Any,
+) -> None:
+    executor = executor_ref()
+    if executor is None:
+        return
+    executor.handle_kill(existing_sigint_handler, signum, frame)
+
+
 def join_messages(strings: List[str]) -> str:
     return " ".join(x.strip() for x in strings if x.strip())

@@ -130,7 +143,10 @@ def __init__(
         # shutdown of the main process which sends SIGTERM signals to terminate all
         # child processes.
         existing_sigint_handler = signal.getsignal(signal.SIGINT)
-        signal.signal(signal.SIGINT, partial(self.handle_kill, existing_sigint_handler))
+        signal.signal(
+            signal.SIGINT,
+            partial(_handle_kill_through_weakref, ref(self), existing_sigint_handler),
+        )

         self.meta_data = {}
         assert not (
diff --git a/cluster_tools/tests/test_all.py b/cluster_tools/tests/test_all.py
index 79d3df652..dbb71190c 100644
--- a/cluster_tools/tests/test_all.py
+++ b/cluster_tools/tests/test_all.py
@@ -14,7 +14,6 @@
     from distributed import LocalCluster

 import cluster_tools
-from cluster_tools.executors.dask import DaskExecutor


 # "Worker" functions.
@@ -79,10 +78,14 @@ def get_executors(with_debug_sequential: bool = False) -> List[cluster_tools.Exe
         executors.append(cluster_tools.get_executor("sequential"))
     if "dask" in executor_keys:
         if not _dask_cluster:
-            from distributed import LocalCluster
+            from distributed import LocalCluster, Worker

-            _dask_cluster = LocalCluster()
-        executors.append(cluster_tools.get_executor("dask", address=_dask_cluster))
+            _dask_cluster = LocalCluster(
+                worker_class=Worker, resources={"mem": 20e9, "cpus": 4}, nthreads=6
+            )
+        executors.append(
+            cluster_tools.get_executor("dask", job_resources={"address": _dask_cluster})
+        )
     if "test_pickling" in executor_keys:
         executors.append(cluster_tools.get_executor("test_pickling"))
     if "pbs" in executor_keys:
@@ -328,7 +331,8 @@ def run_map(executor: cluster_tools.Executor) -> None:
         assert list(result) == [4, 9, 16]

     for exc in get_executors():
-        run_map(exc)
+        if not isinstance(exc, cluster_tools.DaskExecutor):
+            run_map(exc)


 def test_executor_args() -> None:
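
Usage sketch: the snippet below shows how the features added in this patch might be combined. It is an assumption-laden example, not part of the patch: the scheduler address tcp://127.0.0.1:8786 is hypothetical, and it presumes a running dask scheduler whose workers were started with matching --resources "mem=... cpus=..." flags, as described in the DaskExecutor docstring above.

import os

import cluster_tools

# Hypothetical scheduler address; alternatively pass it as
# job_resources={"address": "tcp://..."} instead of using the env var.
os.environ["DASK_ADDRESS"] = "tcp://127.0.0.1:8786"

with cluster_tools.get_executor(
    "dask",
    # "mem" is parsed into bytes (e.g. "2G"), and "cpus-per-task" becomes the
    # "cpus" dask resource; workers must advertise matching resources.
    job_resources={"mem": "2G", "cpus-per-task": "4"},
) as executor:
    # Each submitted call runs in its own spawned process on the worker,
    # with the local environment variables propagated to it.
    futures = [executor.submit(pow, 2, exp) for exp in range(4)]
    print([fut.result() for fut in futures])  # [1, 2, 4, 8]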