
More dask features #959

Merged: 15 commits, Nov 16, 2023
cluster_tools/cluster_tools/executors/dask.py (131 changes: 125 additions & 6 deletions)
@@ -1,7 +1,11 @@
 import os
+import re
+import signal
+import traceback
 from concurrent import futures
 from concurrent.futures import Future
 from functools import partial
+from multiprocessing import Queue, get_context
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -11,6 +15,7 @@
     Iterator,
     List,
     Optional,
+    Set,
     TypeVar,
     cast,
 )
@@ -28,23 +33,103 @@
 _S = TypeVar("_S")
 
 
+def _run_in_nanny(
+    queue: Queue, __fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
+) -> None:
+    try:
+        __env = cast(Dict[str, str], kwargs.pop("__env"))
+        for key, value in __env.items():
+            os.environ[key] = value
+
+        ret = __fn(*args, **kwargs)
+        queue.put({"value": ret})
+    except Exception as exc:
+        queue.put({"exception": exc})
+
+def _run_with_nanny(
+    __fn: Callable[_P, _T],
+    *args: _P.args,
+    **kwargs: _P.kwargs,
+) -> _T:
+    mp_context = get_context("spawn")
+    q = mp_context.Queue()
+    p = mp_context.Process(target=_run_in_nanny, args=(q, __fn) + args, kwargs=kwargs)
+    p.start()
+    p.join()
+    ret = q.get(timeout=0.1)
+    if "exception" in ret:
+        raise ret["exception"]
+    else:
+        return ret["value"]
+
+
+def _parse_mem(size: str) -> int:
+    units = {"": 1, "K": 2**10, "M": 2**20, "G": 2**30, "T": 2**40}
+    m = re.match(r"^([\d\.]+)\s*([a-zA-Z]{0,3})$", str(size).strip())
+    assert m is not None
+    number, unit = float(m.group(1)), m.group(2).upper()
+    return int(number * units[unit])
+
+
 class DaskExecutor(futures.Executor):
+    """
+    The `DaskExecutor` allows running workloads on a dask cluster.
+
+    The executor can be constructed with an existing dask `Client` or
+    from a declarative configuration. The address of the dask scheduler
+    can be part of the configuration or supplied via the `DASK_ADDRESS`
+    environment variable.
+
+    There is support for resource-based scheduling. By default, `mem` and
+    `cpus-per-task` are supported. To make use of them, the dask workers
+    should be started with:
+    `python -m dask worker --no-nanny --nthreads 6 tcp://... --resources "mem=1073741824 cpus=8"`
+    """
+
+    client: "Client"
+    pending_futures: Set[Future]
+    job_resources: Optional[Dict[str, Any]]
+    is_shutting_down = False
 
     def __init__(
-        self,
-        client: "Client",
+        self, client: "Client", job_resources: Optional[Dict[str, Any]] = None
     ) -> None:
         self.client = client
+        self.pending_futures = set()
+        self.job_resources = job_resources
 
+        if self.job_resources is not None:
+            # `mem` needs to be a number for dask, so we need to parse it
+            if "mem" in self.job_resources:
+                self.job_resources["mem"] = _parse_mem(self.job_resources["mem"])
+            if "cpus-per-task" in self.job_resources:
+                self.job_resources["cpus"] = int(
+                    self.job_resources.pop("cpus-per-task")
+                )
+
+        # Clean up if a SIGINT signal is received. However, do not interfere
+        # with the existing signal handler of the process or with the shutdown
+        # of the main process, which sends SIGTERM signals to terminate all
+        # child processes.
+        existing_sigint_handler = signal.getsignal(signal.SIGINT)
+        signal.signal(signal.SIGINT, partial(self.handle_kill, existing_sigint_handler))
+
     @classmethod
     def from_config(
         cls,
-        job_resources: Dict[str, Any],
+        job_resources: Dict[str, str],
         **_kwargs: Any,
     ) -> "DaskExecutor":
         from distributed import Client
 
-        return cls(Client(**job_resources))
+        job_resources = job_resources.copy()
+        address = job_resources.pop("address", None)
+        if address is None:
+            address = os.environ.get("DASK_ADDRESS", None)
+
+        client = Client(address=address)
+        return cls(client, job_resources=job_resources)
 
     @classmethod
     def as_completed(cls, futures: List["Future[_T]"]) -> Iterator["Future[_T]"]:
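
Note on the configuration path above: `_parse_mem` only understands plain byte counts and single-letter binary suffixes (K, M, G, T); a multi-letter suffix such as "GB" is not in the unit table and would fail the lookup. A standalone sketch of the helper's behavior (the example values are made up):

import re

def parse_mem(size: str) -> int:
    # Mirrors _parse_mem from the diff: a number plus an optional binary unit.
    units = {"": 1, "K": 2**10, "M": 2**20, "G": 2**30, "T": 2**40}
    m = re.match(r"^([\d\.]+)\s*([a-zA-Z]{0,3})$", str(size).strip())
    assert m is not None
    number, unit = float(m.group(1)), m.group(2).upper()
    return int(number * units[unit])

assert parse_mem("1073741824") == 2**30  # plain byte counts pass through
assert parse_mem("512M") == 512 * 2**20
assert parse_mem("1.5G") == 1610612736   # fractional sizes are rounded down to bytes

Given this, a call such as `DaskExecutor.from_config({"address": "tcp://scheduler:8786", "mem": "4G", "cpus-per-task": "2"})` (scheduler address invented for illustration) connects a `Client` to that address and stores `{"mem": 4294967296, "cpus": 2}` as the dask resources attached to every submitted task.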
@@ -72,7 +157,20 @@ def submit(  # type: ignore[override]
                     __fn,
                 ),
             )
-        fut = self.client.submit(partial(__fn, *args, **kwargs))
+
+        kwargs["__env"] = os.environ.copy()
+
+        # We run the functions submitted to dask in a separate process so that
+        # the GIL is not held for too long, because dask workers need to be
+        # able to communicate with the scheduler regularly.
+        __fn = partial(_run_with_nanny, __fn)
+
+        fut = self.client.submit(
+            partial(__fn, *args, **kwargs), pure=False, resources=self.job_resources
+        )
+
+        self.pending_futures.add(fut)
+        fut.add_done_callback(self.pending_futures.remove)
 
         enrich_future_with_uncaught_warning(fut)
         return fut
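
To make the `_run_with_nanny` round trip above concrete, here is a minimal, self-contained sketch of the same spawn-process pattern (the names and example task are invented; note that this sketch reads from the queue before joining, whereas the diff joins first and then reads with a short timeout):

import os
from multiprocessing import get_context

def _child(queue, fn, *args, **kwargs):
    # Counterpart of _run_in_nanny: replay the parent's environment, run the
    # function, and report either a value or an exception through the queue.
    for key, value in kwargs.pop("__env").items():
        os.environ[key] = value
    try:
        queue.put({"value": fn(*args, **kwargs)})
    except Exception as exc:
        queue.put({"exception": exc})

def _scaled(x):
    return x * int(os.environ["EXAMPLE_FACTOR"])

if __name__ == "__main__":
    ctx = get_context("spawn")
    queue = ctx.Queue()
    proc = ctx.Process(
        target=_child,
        args=(queue, _scaled, 21),
        kwargs={"__env": {"EXAMPLE_FACTOR": "2"}},
    )
    proc.start()
    result = queue.get()  # read before join; a large payload could otherwise deadlock
    proc.join()
    if "exception" in result:
        raise result["exception"]
    print(result["value"])  # 42

Forwarding `os.environ` matters because a dask worker generally runs on a different machine and does not inherit the submitting process's environment.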
@@ -125,8 +223,29 @@ def map(  # type: ignore[override]
     def forward_log(self, fut: "Future[_T]") -> _T:
         return fut.result()
 
+    def handle_kill(
+        self, existing_sigint_handler: Any, signum: Optional[int], frame: Any
+    ) -> None:
+        if self.is_shutting_down:
+            return
+
+        self.is_shutting_down = True
+
+        self.client.cancel(list(self.pending_futures))
+
+        if (
+            existing_sigint_handler  # pylint: disable=comparison-with-callable
+            != signal.default_int_handler
+            and callable(existing_sigint_handler)  # Could also be signal.SIG_IGN
+        ):
+            existing_sigint_handler(signum, frame)
+
     def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
+        print(f"{wait=} {cancel_futures=}")
+        traceback.print_stack()
         if wait:
-            self.client.close(timeout=60 * 60 * 24)
+            for fut in list(self.pending_futures):
+                fut.result()
+            self.client.close(timeout=60 * 60)  # 1 hour
         else:
             self.client.close()
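
The SIGINT chaining registered in `__init__` and consumed by `handle_kill` follows a common pattern; a standalone sketch (the handler and its message are invented here):

import signal
from functools import partial

def handle_kill(previous_handler, signum, frame):
    print("cancelling pending futures...")  # stands in for client.cancel(...)
    # Defer to whatever handler was installed before, unless it is the default
    # handler (which raises KeyboardInterrupt) or a non-callable like SIG_IGN.
    if previous_handler != signal.default_int_handler and callable(previous_handler):
        previous_handler(signum, frame)

previous = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, partial(handle_kill, previous))
signal.raise_signal(signal.SIGINT)  # prints the message; the default handler is deliberately not re-invoked

Because `is_shutting_down` is set on the first delivery, a second Ctrl-C arriving during cleanup does not cancel the pending futures twice.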