From b9f6f07e1de38fde29a1c5effaf7e1063e64da25 Mon Sep 17 00:00:00 2001 From: Jonathan Striebel Date: Mon, 6 Feb 2023 15:06:06 +0100 Subject: [PATCH] cluster tools: typing & refactoring (#858) * python 3.7+, update mypy * unrelated commit * partial typing, organize executors into modules * more typing, refactor utils * try fn * Revert "try fn" This reverts commit 3016dfc476b7f46a80119e57f615080450dd3a1e. * ignore errors for py3.8 * fix Future subscription * allow unused kwargs in init * fix __cfut_options usage * CI: don't fail-fast for webknossos-linux tests * more typing * fixes * more fixes * some job-ids are ints * fix kube.py * ensure job_ids are strings * fix linter and tests * Update cluster_tools/cluster_tools/_utils/string.py Co-authored-by: Philipp Otto * apply PR feedback * rename to reflection * readd debugging var --------- Co-authored-by: Philipp Otto --- .github/workflows/ci.yml | 1 + cluster_tools/cluster_tools/__init__.py | 361 ++++-------------- cluster_tools/cluster_tools/_utils/call.py | 48 +++ .../cluster_tools/_utils/file_wait_thread.py | 106 +++++ .../multiprocessing_logging_handler.py | 18 +- .../cluster_tools/_utils/pickling.py | 52 +++ .../cluster_tools/_utils/reflection.py | 34 ++ cluster_tools/cluster_tools/_utils/string.py | 22 ++ .../cluster_tools/{ => _utils}/tailf.py | 2 +- cluster_tools/cluster_tools/_utils/warning.py | 68 ++++ .../executors/debug_sequential.py | 54 +++ .../executors/multiprocessing.py | 252 ++++++++++++ .../cluster_tools/executors/pickle.py | 37 ++ .../cluster_tools/executors/sequential.py | 29 ++ cluster_tools/cluster_tools/pickling.py | 70 ---- cluster_tools/cluster_tools/remote.py | 34 +- .../schedulers/cluster_executor.py | 240 +++++++----- .../cluster_tools/schedulers/kube.py | 61 ++- cluster_tools/cluster_tools/schedulers/pbs.py | 47 ++- .../cluster_tools/schedulers/slurm.py | 162 +++++--- cluster_tools/cluster_tools/util.py | 223 ----------- cluster_tools/poetry.lock | 117 +++--- cluster_tools/pyproject.toml | 7 +- cluster_tools/tests/test_all.py | 78 ++-- cluster_tools/tests/test_deref_main.py | 8 +- cluster_tools/tests/test_kubernetes.py | 9 +- cluster_tools/tests/test_multiprocessing.py | 22 +- cluster_tools/tests/test_slurm.py | 121 +++--- .../webknossos/dataset/_utils/pims_images.py | 4 + 29 files changed, 1307 insertions(+), 980 deletions(-) create mode 100644 cluster_tools/cluster_tools/_utils/call.py create mode 100644 cluster_tools/cluster_tools/_utils/file_wait_thread.py rename cluster_tools/cluster_tools/{ => _utils}/multiprocessing_logging_handler.py (89%) create mode 100644 cluster_tools/cluster_tools/_utils/pickling.py create mode 100644 cluster_tools/cluster_tools/_utils/reflection.py create mode 100644 cluster_tools/cluster_tools/_utils/string.py rename cluster_tools/cluster_tools/{ => _utils}/tailf.py (99%) create mode 100644 cluster_tools/cluster_tools/_utils/warning.py create mode 100644 cluster_tools/cluster_tools/executors/debug_sequential.py create mode 100644 cluster_tools/cluster_tools/executors/multiprocessing.py create mode 100644 cluster_tools/cluster_tools/executors/pickle.py create mode 100644 cluster_tools/cluster_tools/executors/sequential.py delete mode 100644 cluster_tools/cluster_tools/pickling.py delete mode 100644 cluster_tools/cluster_tools/util.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e775efe83..d73a03258 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -139,6 +139,7 @@ jobs: matrix: python-version: [3.7, 3.8, 3.9] group: [1, 2, 3] + 
fail-fast: false defaults: run: working-directory: webknossos diff --git a/cluster_tools/cluster_tools/__init__.py b/cluster_tools/cluster_tools/__init__.py index b3a0afb0e..46a07c542 100644 --- a/cluster_tools/cluster_tools/__init__.py +++ b/cluster_tools/cluster_tools/__init__.py @@ -7,306 +7,35 @@ from functools import partial from pathlib import Path from shutil import rmtree -from typing import Union +from typing import Any, Union, overload -from . import pickling -from .multiprocessing_logging_handler import get_multiprocessing_logging_setup_fn -from .schedulers.cluster_executor import ClusterExecutor, RemoteOutOfMemoryException -from .schedulers.kube import KubernetesExecutor -from .schedulers.pbs import PBSExecutor -from .schedulers.slurm import SlurmExecutor -from .util import enrich_future_with_uncaught_warning +from typing_extensions import Literal +from cluster_tools._utils.warning import enrich_future_with_uncaught_warning +from cluster_tools.executors.debug_sequential import DebugSequentialExecutor +from cluster_tools.executors.multiprocessing import MultiprocessingExecutor +from cluster_tools.executors.pickle import PickleExecutor +from cluster_tools.executors.sequential import SequentialExecutor +from cluster_tools.schedulers.cluster_executor import ( + ClusterExecutor, + RemoteOutOfMemoryException, +) +from cluster_tools.schedulers.kube import KubernetesExecutor +from cluster_tools.schedulers.pbs import PBSExecutor +from cluster_tools.schedulers.slurm import SlurmExecutor -def get_existent_kwargs_subset(whitelist, kwargs): - new_kwargs = {} - for arg_name in whitelist: - if arg_name in kwargs: - new_kwargs[arg_name] = kwargs[arg_name] +# For backwards-compatibility: +WrappedProcessPoolExecutor = MultiprocessingExecutor - return new_kwargs - -PROCESS_POOL_KWARGS_WHITELIST = ["max_workers", "initializer", "initargs"] - - -class WrappedProcessPoolExecutor(ProcessPoolExecutor): - """ - Wraps the ProcessPoolExecutor to add various features: - - map_to_futures and map_unordered method - - pickling of job's output (see output_pickle_path_getter and output_pickle_path) - - job submission via pickling to circumvent bug in python < 3.8 (see MULTIPROCESSING_VIA_IO_TMP_DIR) - """ - - def __init__(self, **kwargs): - assert (not "start_method" in kwargs or kwargs["start_method"] is None) or ( - not "mp_context" in kwargs - ), "Cannot use both `start_method` and `mp_context` kwargs." - - new_kwargs = get_existent_kwargs_subset(PROCESS_POOL_KWARGS_WHITELIST, kwargs) - - mp_context = None - - if "mp_context" in kwargs: - mp_context = kwargs["mp_context"] - elif "start_method" in kwargs and kwargs["start_method"] is not None: - mp_context = multiprocessing.get_context(kwargs["start_method"]) - elif "MULTIPROCESSING_DEFAULT_START_METHOD" in os.environ: - mp_context = multiprocessing.get_context( - os.environ["MULTIPROCESSING_DEFAULT_START_METHOD"] - ) - else: - mp_context = multiprocessing.get_context("spawn") - - new_kwargs["mp_context"] = mp_context - - ProcessPoolExecutor.__init__(self, **new_kwargs) - - def submit(self, *args, **kwargs): - - output_pickle_path = None - if "__cfut_options" in kwargs: - output_pickle_path = kwargs["__cfut_options"]["output_pickle_path"] - del kwargs["__cfut_options"] - - if os.environ.get("MULTIPROCESSING_VIA_IO"): - # If MULTIPROCESSING_VIA_IO is set, _submit_via_io is used to - # workaround size constraints in pythons multiprocessing - # implementation. 
Also see https://github.com/python/cpython/pull/10305/files - # This should be fixed in python 3.8 - submit_fn = self._submit_via_io - else: - submit_fn = super().submit - - # Depending on the start_method and output_pickle_path, wrapper functions may need to be - # executed in the new process context, before the actual code is ran. - # These wrapper functions consume their arguments from *args, **kwargs and assume - # that the next argument will be another function that is then called. - # The call_stack holds all of these wrapper functions and their arguments in the correct order. - # For example, call_stack = [wrapper_fn_1, wrapper_fn_1_arg_1, wrapper_fn_2, actual_fn, actual_fn_arg_1] - # where wrapper_fn_1 is called, which eventually calls wrapper_fn_2, which eventually calls actual_fn. - call_stack = [] - - if self._mp_context.get_start_method() != "fork": - # If a start_method other than the default "fork" is used, logging needs to be re-setup, - # because the programming context is not inherited in those cases. - multiprocessing_logging_setup_fn = get_multiprocessing_logging_setup_fn() - call_stack.extend( - [ - WrappedProcessPoolExecutor._setup_logging_and_execute, - multiprocessing_logging_setup_fn, - ] - ) - - if output_pickle_path is not None: - call_stack.extend( - [ - WrappedProcessPoolExecutor._execute_and_persist_function, - output_pickle_path, - ] - ) - - fut = submit_fn(*call_stack, *args, **kwargs) - - enrich_future_with_uncaught_warning(fut) - return fut - - def _submit_via_io(self, *args, **kwargs): - - func = args[0] - args = args[1:] - - opt_tmp_dir = os.environ.get("MULTIPROCESSING_VIA_IO_TMP_DIR") - if opt_tmp_dir is not None: - dirpath = tempfile.mkdtemp(dir=opt_tmp_dir) - else: - dirpath = tempfile.mkdtemp() - - output_pickle_path = Path(dirpath) / "jobdescription.pickle" - - with open(output_pickle_path, "wb") as file: - pickling.dump((func, args, kwargs), file) - - future = super().submit( - WrappedProcessPoolExecutor._execute_via_io, output_pickle_path - ) - - future.add_done_callback( - partial(WrappedProcessPoolExecutor._remove_tmp_file, dirpath) - ) - - return future - - @staticmethod - def _remove_tmp_file(path, _future): - rmtree(path) - - @staticmethod - def _setup_logging_and_execute(multiprocessing_logging_setup_fn, *args, **kwargs): - - func = args[0] - args = args[1:] - - multiprocessing_logging_setup_fn() - - return func(*args, **kwargs) - - @staticmethod - def _execute_via_io(serialized_function_info_path): - - with open(serialized_function_info_path, "rb") as file: - (func, args, kwargs) = pickling.load(file) - return func(*args, **kwargs) - - @staticmethod - def _execute_and_persist_function(output_pickle_path, *args, **kwargs): - - func = args[0] - args = args[1:] - - try: - result = True, func(*args, **kwargs) - except Exception as exc: - result = False, exc - logging.warning(f"Job computation failed with:\n{exc.__repr__()}") - - if result[0]: - # Only pickle the result in the success case, since the output - # is used as a checkpoint. - # Note that this behavior differs a bit from the cluster executor - # which will always serialize the output (even exceptions) to - # disk. However, the output will have a .preliminary prefix at first - # which is only removed in the success case so that a checkpoint at - # the desired target only exists if the job was successful. 
- with open(output_pickle_path, "wb") as file: - pickling.dump(result, file) - return result[1] - else: - raise result[1] - - def map_unordered(self, func, args): - - futs = self.map_to_futures(func, args) - - # Return a separate generator to avoid that map_unordered - # is executed lazily (otherwise, jobs would be submitted - # lazily, as well). - def result_generator(): - for fut in futures.as_completed(futs): - yield fut.result() - - return result_generator() - - def map_to_futures(self, func, args, output_pickle_path_getter=None): - - if output_pickle_path_getter is not None: - futs = [ - self.submit( - func, - arg, - __cfut_options={ - "output_pickle_path": output_pickle_path_getter(arg) - }, - ) - for arg in args - ] - else: - futs = [self.submit(func, arg) for arg in args] - - return futs - - def forward_log(self, fut): - """ - Similar to the cluster executor, this method Takes a future from which the log file is forwarded to the active - process. This method blocks as long as the future is not done. - """ - - # Since the default behavior of process pool executors is to show the log in the main process - # we don't need to do anything except for blocking until the future is done. - return fut.result() - - -class SequentialExecutor(WrappedProcessPoolExecutor): - """ - The same as WrappedProcessPoolExecutor, but always uses only one core. In essence, - this is a sequential executor approach, but it still makes use of the standard pool approach. - That way, switching between different executors should always work without any problems. - """ - - def __init__(self, **kwargs): - kwargs["max_workers"] = 1 - WrappedProcessPoolExecutor.__init__(self, **kwargs) - - -class DebugSequentialExecutor(SequentialExecutor): - """ - Only use for debugging purposes. This executor does not spawn new processes for its jobs. Therefore, - setting breakpoint()'s should be possible without context-related problems. - """ - - def submit(self, *args, **kwargs): - - output_pickle_path = None - if "__cfut_options" in kwargs: - output_pickle_path = kwargs["__cfut_options"]["output_pickle_path"] - del kwargs["__cfut_options"] - - if output_pickle_path is not None: - fut = self._blocking_submit( - WrappedProcessPoolExecutor._execute_and_persist_function, - output_pickle_path, - *args, - **kwargs, - ) - else: - fut = self._blocking_submit(*args, **kwargs) - - enrich_future_with_uncaught_warning(fut) - return fut - - def _blocking_submit(self, *args, **kwargs): - - func = args[0] - args = args[1:] - - fut = futures.Future() - result = func(*args, **kwargs) - fut.set_result(result) - - return fut - - -def pickle_identity(obj): - return pickling.loads(pickling.dumps(obj)) - - -def pickle_identity_executor(func, *args, **kwargs): - result = func(*args, **kwargs) - return pickle_identity(result) - - -class PickleExecutor(WrappedProcessPoolExecutor): - """ - The same as WrappedProcessPoolExecutor, but always pickles input and output of the jobs. - When using this executor for automated tests, it is ensured that using cluster executors in production - won't provoke pickling-related problems. 
- """ - - def submit(self, _func, *_args, **_kwargs): - - (func, args, kwargs) = pickle_identity((_func, _args, _kwargs)) - return super().submit(pickle_identity_executor, func, *args, **kwargs) - - -def noop(): +def _noop() -> bool: return True did_start_test_multiprocessing = False -def test_valid_multiprocessing(): - +def _test_valid_multiprocessing() -> None: msg = """ ############################################################### An attempt has been made to start a new process before the @@ -324,7 +53,7 @@ def test_valid_multiprocessing(): with get_executor("multiprocessing") as executor: try: - res_fut = executor.submit(noop) + res_fut = executor.submit(_noop) assert res_fut.result() == True, msg except RuntimeError as exc: raise Exception(msg) from exc @@ -332,8 +61,52 @@ def test_valid_multiprocessing(): raise Exception(msg) from exc -def get_executor(environment, **kwargs): +@overload +def get_executor(environment: Literal["slurm"], **kwargs: Any) -> SlurmExecutor: + ... + + +@overload +def get_executor(environment: Literal["pbs"], **kwargs: Any) -> PBSExecutor: + ... + + +@overload +def get_executor( + environment: Literal["kubernetes"], **kwargs: Any +) -> KubernetesExecutor: + ... + + +@overload +def get_executor( + environment: Literal["multiprocessing"], **kwargs: Any +) -> MultiprocessingExecutor: + ... + + +@overload +def get_executor( + environment: Literal["sequential"], **kwargs: Any +) -> SequentialExecutor: + ... + + +@overload +def get_executor( + environment: Literal["debug_sequential"], **kwargs: Any +) -> DebugSequentialExecutor: + ... + + +@overload +def get_executor( + environment: Literal["test_pickling"], **kwargs: Any +) -> PickleExecutor: + ... + +def get_executor(environment: str, **kwargs: Any) -> "Executor": if environment == "slurm": return SlurmExecutor(**kwargs) elif environment == "pbs": @@ -344,9 +117,9 @@ def get_executor(environment, **kwargs): global did_start_test_multiprocessing if not did_start_test_multiprocessing: did_start_test_multiprocessing = True - test_valid_multiprocessing() + _test_valid_multiprocessing() - return WrappedProcessPoolExecutor(**kwargs) + return MultiprocessingExecutor(**kwargs) elif environment == "sequential": return SequentialExecutor(**kwargs) elif environment == "debug_sequential": @@ -356,4 +129,4 @@ def get_executor(environment, **kwargs): raise Exception("Unknown executor: {}".format(environment)) -Executor = Union[ClusterExecutor, WrappedProcessPoolExecutor] +Executor = Union[ClusterExecutor, MultiprocessingExecutor] diff --git a/cluster_tools/cluster_tools/_utils/call.py b/cluster_tools/cluster_tools/_utils/call.py new file mode 100644 index 000000000..f726b699a --- /dev/null +++ b/cluster_tools/cluster_tools/_utils/call.py @@ -0,0 +1,48 @@ +import subprocess +from typing import Optional, Tuple + + +def call(command: str, stdin: Optional[str] = None) -> Tuple[str, str, int]: + """Invokes a shell command as a subprocess, optionally with some + data sent to the standard input. Returns the standard output data, + the standard error, and the return code. 
+ """ + if stdin is not None: + stdin_flag = subprocess.PIPE + else: + stdin_flag = None + p = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + ) + return p.stdout, p.stderr, p.returncode + + +class CommandError(Exception): + """Raised when a shell command exits abnormally.""" + + def __init__( + self, command: str, code: int, stderr: str + ): # pylint: disable=super-init-not-called + self.command = command + self.code = code + self.stderr = stderr + + def __str__(self) -> str: + return "%s exited with status %i: %s" % ( + repr(self.command), + self.code, + repr(self.stderr), + ) + + +def chcall(command: str, stdin: Optional[str] = None) -> Tuple[str, str]: + """Like ``call`` but raises an exception when the return code is + nonzero. Only returns the stdout and stderr data. + """ + stdout, stderr, code = call(command, stdin) + if code != 0: + raise CommandError(command, code, stderr) + return stdout, stderr diff --git a/cluster_tools/cluster_tools/_utils/file_wait_thread.py b/cluster_tools/cluster_tools/_utils/file_wait_thread.py new file mode 100644 index 000000000..2bdb8046f --- /dev/null +++ b/cluster_tools/cluster_tools/_utils/file_wait_thread.py @@ -0,0 +1,106 @@ +import logging +import os +import threading +import time +from typing import TYPE_CHECKING, Callable, Dict, Tuple + +if TYPE_CHECKING: + from cluster_tools.schedulers.cluster_executor import ClusterExecutor + + +class FileWaitThread(threading.Thread): + """A thread that polls the filesystem waiting for a list of files to + be created. When a specified file is created, it invokes a callback. + """ + + MAX_RETRY = 30 + + def __init__( + self, + callback: Callable[[str, bool], None], + executor: "ClusterExecutor", + interval: int = 2, + ): + """The callable ``callback`` will be invoked with value + associated with the filename of each file that is created. + ``interval`` specifies the polling rate. + """ + threading.Thread.__init__(self) + self.callback = callback + self.interval = interval + self.waiting: Dict[str, str] = {} + self.retryMap: Dict[str, int] = {} + self.lock = threading.Lock() + self.shutdown = False + self.executor = executor + + def stop(self) -> None: + """Stop the thread soon.""" + with self.lock: + self.shutdown = True + + def waitFor(self, filename: str, value: str) -> None: + """Adds a new filename (and its associated callback value) to + the set of files being waited upon. + """ + with self.lock: + self.waiting[filename] = value + + def run(self) -> None: + def handle_completed_job( + job_id: str, filename: str, failed_early: bool + ) -> None: + self.callback(job_id, failed_early) + del self.waiting[filename] + + while True: + with self.lock: + if self.shutdown: + return + + pending_tasks = self.executor.get_pending_tasks() + + # Poll for each file. + for filename in list(self.waiting): + job_id = self.waiting[filename] + if job_id in pending_tasks: + # Don't check status of pending tasks, since this + # can vastly slow down the polling. 
+ continue + + if os.path.exists(filename): + # Check for output file as a fast indicator for job completion + handle_completed_job(job_id, filename, False) + elif self.executor is not None: + status = self.executor.check_job_state(job_id) + + # We have to re-check for the output file since this could be created in the mean time + if os.path.exists(filename): + handle_completed_job(job_id, filename, False) + else: + if status == "completed": + self.retryMap[filename] = self.retryMap.get(filename, 0) + self.retryMap[filename] += 1 + + if self.retryMap[filename] <= FileWaitThread.MAX_RETRY: + # Retry by looping again + logging.warning( + "Job state is completed, but {} couldn't be found. Retrying {}/{}".format( + filename, + self.retryMap[filename], + FileWaitThread.MAX_RETRY, + ) + ) + else: + logging.error( + "Job state is completed, but {} couldn't be found.".format( + filename + ) + ) + handle_completed_job(job_id, filename, True) + + elif status == "failed": + handle_completed_job(job_id, filename, True) + elif status == "ignore": + pass + time.sleep(self.interval) diff --git a/cluster_tools/cluster_tools/multiprocessing_logging_handler.py b/cluster_tools/cluster_tools/_utils/multiprocessing_logging_handler.py similarity index 89% rename from cluster_tools/cluster_tools/multiprocessing_logging_handler.py rename to cluster_tools/cluster_tools/_utils/multiprocessing_logging_handler.py index 17c135455..7a148e336 100644 --- a/cluster_tools/cluster_tools/multiprocessing_logging_handler.py +++ b/cluster_tools/cluster_tools/_utils/multiprocessing_logging_handler.py @@ -14,7 +14,7 @@ # Inspired by https://stackoverflow.com/a/894284 -class MultiProcessingHandler(logging.Handler): +class _MultiprocessingLoggingHandler(logging.Handler): """This class wraps a logging handler and instantiates a multiprocessing queue. It asynchronously receives messages from the queue and emits them using the wrapped handler. The queue can be used by logging.QueueHandlers in other processes @@ -23,7 +23,7 @@ class MultiProcessingHandler(logging.Handler): the logging context is not copied to the subprocess but instead logging needs to be re-setup. """ - def __init__(self, name: str, wrapped_handler: logging.Handler): + def __init__(self, name: str, wrapped_handler: logging.Handler) -> None: super().__init__() self.wrapped_handler = wrapped_handler @@ -88,7 +88,7 @@ def _setup_logging_multiprocessing( fork is used) by setting up QueueHandler loggers for each queue and level so that log messages are piped to the original loggers in the main process. 
""" - warnings.filters = filters # type: ignore[attr-defined] + warnings.filters = filters root_logger = getLogger() for handler in root_logger.handlers: @@ -101,16 +101,16 @@ def _setup_logging_multiprocessing( root_logger.addHandler(handler) -def get_multiprocessing_logging_setup_fn() -> Any: +def _get_multiprocessing_logging_setup_fn() -> Any: root_logger = getLogger() queues = [] levels = [] for i, handler in enumerate(list(root_logger.handlers)): - # Wrap logging handlers in MultiProcessingHandlers to make them work in a multiprocessing setup + # Wrap logging handlers in _MultiprocessingLoggingHandlers to make them work in a multiprocessing setup # when using start_methods other than fork, for example, spawn or forkserver - if not isinstance(handler, MultiProcessingHandler): - mp_handler = MultiProcessingHandler( + if not isinstance(handler, _MultiprocessingLoggingHandler): + mp_handler = _MultiprocessingLoggingHandler( f"multi-processing-handler-{i}", handler ) @@ -123,11 +123,11 @@ def get_multiprocessing_logging_setup_fn() -> Any: levels.append(mp_handler.level) # Return a logging setup function that when called will setup QueueHandler loggers - # reusing the queues of each wrapped MultiProcessingHandler. This way all log messages + # reusing the queues of each wrapped _MultiprocessingLoggingHandler. This way all log messages # are forwarded to the main process. return functools.partial( _setup_logging_multiprocessing, queues, levels, - filters=warnings.filters, # type: ignore[attr-defined] + filters=warnings.filters, ) diff --git a/cluster_tools/cluster_tools/_utils/pickling.py b/cluster_tools/cluster_tools/_utils/pickling.py new file mode 100644 index 000000000..bca5a1ac3 --- /dev/null +++ b/cluster_tools/cluster_tools/_utils/pickling.py @@ -0,0 +1,52 @@ +import pickle +import sys +from typing import Any, BinaryIO, Optional + +from cluster_tools._utils.warning import warn_after + +WARNING_TIMEOUT = 10 * 60 # seconds + + +def _get_suitable_pickle_protocol() -> int: + # Protocol 4 allows to serialize objects larger than 4 GiB, but is only supported + # beginning from Python 3.4 + protocol = 4 if sys.version_info[0] >= 3 and sys.version_info[1] >= 4 else 3 + return protocol + + +@warn_after("pickle.dumps", WARNING_TIMEOUT) +def dumps(*args: Any, **kwargs: Any) -> bytes: + return pickle.dumps(*args, protocol=_get_suitable_pickle_protocol(), **kwargs) # type: ignore[misc] + + +@warn_after("pickle.dump", WARNING_TIMEOUT) +def dump(*args: Any, **kwargs: Any) -> None: + pickle.dump(*args, protocol=_get_suitable_pickle_protocol(), **kwargs) # type: ignore[misc] + + +@warn_after("pickle.loads", WARNING_TIMEOUT) +def loads(*args: Any, **kwargs: Any) -> Any: + assert ( + "custom_main_path" not in kwargs + ), "loads does not implement support for the argument custom_main_path" + return pickle.loads(*args, **kwargs) + + +class _RenameUnpickler(pickle.Unpickler): + custom_main_path: Optional[str] + + def find_class(self, module: str, name: str) -> Any: + renamed_module = module + if module == "__main__" and self.custom_main_path is not None: + renamed_module = self.custom_main_path + + return super(_RenameUnpickler, self).find_class(renamed_module, name) + + +@warn_after("pickle.load", WARNING_TIMEOUT) +def load(f: BinaryIO, custom_main_path: Optional[str] = None) -> Any: + unpickler = _RenameUnpickler(f) + unpickler.custom_main_path = ( # pylint: disable=attribute-defined-outside-init + custom_main_path + ) + return unpickler.load() diff --git a/cluster_tools/cluster_tools/_utils/reflection.py 
b/cluster_tools/cluster_tools/_utils/reflection.py new file mode 100644 index 000000000..ed0ea89b1 --- /dev/null +++ b/cluster_tools/cluster_tools/_utils/reflection.py @@ -0,0 +1,34 @@ +import os +import pickle +import sys +from typing import Callable + +WARNING_TIMEOUT = 10 * 60 # seconds + + +def file_path_to_absolute_module(file_path: str) -> str: + """ + Given a file path, return an import path. + :param file_path: A file path. + :return: + """ + assert os.path.exists(file_path) + file_loc, _ = os.path.splitext(file_path) + directory, module = os.path.split(file_loc) + module_path = [module] + while True: + if os.path.exists(os.path.join(directory, "__init__.py")): + directory, package = os.path.split(directory) + module_path.append(package) + else: + break + path = ".".join(module_path[::-1]) + return path + + +def get_function_name(fun: Callable) -> str: + # When using functools.partial, __name__ does not exist + try: + return fun.__name__ if hasattr(fun, "__name__") else get_function_name(fun.func) # type: ignore[attr-defined] + except Exception: + return "" diff --git a/cluster_tools/cluster_tools/_utils/string.py b/cluster_tools/cluster_tools/_utils/string.py new file mode 100644 index 000000000..1c5f573c6 --- /dev/null +++ b/cluster_tools/cluster_tools/_utils/string.py @@ -0,0 +1,22 @@ +import os +import random +import string + + +def local_filename(filename: str = "") -> str: + return os.path.join(os.getenv("CFUT_DIR", ".cfut"), filename) + + +# Instantiate a dedicated generator to avoid being dependent on +# the global seed which some external code might have set. +random_generator = random.Random() + + +def random_string( + length: int = 32, chars: str = (string.ascii_letters + string.digits) +) -> str: + return "".join(random_generator.choice(chars) for i in range(length)) + + +def with_preliminary_postfix(name: str) -> str: + return f"{name}.preliminary" diff --git a/cluster_tools/cluster_tools/tailf.py b/cluster_tools/cluster_tools/_utils/tailf.py similarity index 99% rename from cluster_tools/cluster_tools/tailf.py rename to cluster_tools/cluster_tools/_utils/tailf.py index e68b399b6..2e7add67f 100644 --- a/cluster_tools/cluster_tools/tailf.py +++ b/cluster_tools/cluster_tools/_utils/tailf.py @@ -8,7 +8,7 @@ from typing import Any, Callable -class Tail(object): +class Tail: """Represents a tail command.""" def __init__( diff --git a/cluster_tools/cluster_tools/_utils/warning.py b/cluster_tools/cluster_tools/_utils/warning.py new file mode 100644 index 000000000..d708ae09d --- /dev/null +++ b/cluster_tools/cluster_tools/_utils/warning.py @@ -0,0 +1,68 @@ +import logging +import os +import string +import threading +import time +from concurrent.futures import Future +from typing import Callable, TypeVar + +from typing_extensions import ParamSpec + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +def warn_after( + job: str, seconds: int +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """ + Use as decorator to warn when a function is taking longer than {seconds} seconds. 
+ """ + + def outer(fn: Callable[_P, _T]) -> Callable[_P, _T]: + def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T: + exceeded_timeout = [False] + start_time = time.time() + + def warn_function() -> None: + logging.warning( + "Function {} is taking suspiciously long (longer than {} seconds)".format( + job, seconds + ) + ) + exceeded_timeout[0] = True + + timer = threading.Timer(seconds, warn_function) + timer.start() + + try: + result = fn(*args, **kwargs) + if exceeded_timeout[0]: + end_time = time.time() + logging.warning( + "Function {} succeeded after all (took {} seconds)".format( + job, int(end_time - start_time) + ) + ) + finally: + timer.cancel() + return result + + return inner + + return outer + + +def enrich_future_with_uncaught_warning(f: Future) -> None: + def warn_on_exception(future: Future) -> None: + maybe_exception = future.exception() + if maybe_exception is not None: + logging.error( + "A future crashed with an exception: {}. Future: {}".format( + maybe_exception, future + ) + ) + + if not hasattr(f, "is_wrapped_by_cluster_tools"): + f.is_wrapped_by_cluster_tools = True # type: ignore[attr-defined] + f.add_done_callback(warn_on_exception) diff --git a/cluster_tools/cluster_tools/executors/debug_sequential.py b/cluster_tools/cluster_tools/executors/debug_sequential.py new file mode 100644 index 000000000..0a6f3514f --- /dev/null +++ b/cluster_tools/cluster_tools/executors/debug_sequential.py @@ -0,0 +1,54 @@ +from concurrent.futures import Future +from multiprocessing.context import BaseContext +from typing import Any, Callable, Optional, Tuple, TypeVar, cast + +from typing_extensions import ParamSpec + +from cluster_tools._utils.warning import enrich_future_with_uncaught_warning +from cluster_tools.executors.multiprocessing import CFutDict, MultiprocessingExecutor +from cluster_tools.executors.sequential import SequentialExecutor + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +class DebugSequentialExecutor(SequentialExecutor): + """ + Only use for debugging purposes. This executor does not spawn new processes for its jobs. Therefore, + setting breakpoint()'s should be possible without context-related problems. 
+ """ + + def submit( # type: ignore[override] + self, + __fn: Callable[_P, _T], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> "Future[_T]": + if "__cfut_options" in kwargs: + output_pickle_path = cast(CFutDict, kwargs["__cfut_options"])[ + "output_pickle_path" + ] + del kwargs["__cfut_options"] + fut = self._blocking_submit( + MultiprocessingExecutor._execute_and_persist_function, # type: ignore[arg-type] + output_pickle_path, # type: ignore[arg-type] + __fn, # type: ignore[arg-type] + *args, + **kwargs, + ) + else: + fut = self._blocking_submit(__fn, *args, **kwargs) + + enrich_future_with_uncaught_warning(fut) + return fut + + def _blocking_submit( + self, + __fn: Callable[_P, _T], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> "Future[_T]": + fut: "Future[_T]" = Future() + result = __fn(*args, **kwargs) + fut.set_result(result) + return fut diff --git a/cluster_tools/cluster_tools/executors/multiprocessing.py b/cluster_tools/cluster_tools/executors/multiprocessing.py new file mode 100644 index 000000000..d48cf2079 --- /dev/null +++ b/cluster_tools/cluster_tools/executors/multiprocessing.py @@ -0,0 +1,252 @@ +import logging +import multiprocessing +import os +import sys +import tempfile +from concurrent import futures +from concurrent.futures import Future, ProcessPoolExecutor +from functools import partial +from multiprocessing.context import BaseContext +from pathlib import Path +from shutil import rmtree +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + TypeVar, + cast, +) + +from typing_extensions import Literal, ParamSpec, TypedDict + +from cluster_tools._utils import pickling +from cluster_tools._utils.multiprocessing_logging_handler import ( + _get_multiprocessing_logging_setup_fn, +) +from cluster_tools._utils.warning import enrich_future_with_uncaught_warning + + +class CFutDict(TypedDict): + output_pickle_path: str + + +_T = TypeVar("_T") +_P = ParamSpec("_P") +_S = TypeVar("_S") + + +class MultiprocessingExecutor(ProcessPoolExecutor): + """ + Wraps the ProcessPoolExecutor to add various features: + - map_to_futures and map_unordered method + - pickling of job's output (see output_pickle_path_getter and output_pickle_path) + - job submission via pickling to circumvent bug in python < 3.8 (see MULTIPROCESSING_VIA_IO_TMP_DIR) + """ + + _mp_context: BaseContext + + def __init__( + self, + *, + max_workers: Optional[int] = None, + start_method: Optional[str] = None, + mp_context: Optional[BaseContext] = None, + initializer: Optional[Callable] = None, + initargs: Tuple = (), + **__kwargs: Any, + ) -> None: + if mp_context is None: + if start_method is not None: + mp_context = multiprocessing.get_context(start_method) + elif "MULTIPROCESSING_DEFAULT_START_METHOD" in os.environ: + mp_context = multiprocessing.get_context( + os.environ["MULTIPROCESSING_DEFAULT_START_METHOD"] + ) + else: + mp_context = multiprocessing.get_context("spawn") + else: + assert ( + start_method is None + ), "Cannot use both `start_method` and `mp_context` kwargs." 
+ + super().__init__( + mp_context=mp_context, + max_workers=max_workers, + initializer=initializer, + initargs=initargs, + ) + + def submit( # type: ignore[override] + self, + __fn: Callable[_P, _T], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> "Future[_T]": + if "__cfut_options" in kwargs: + output_pickle_path = cast(CFutDict, kwargs["__cfut_options"])[ + "output_pickle_path" + ] + del kwargs["__cfut_options"] + else: + output_pickle_path = None + + if os.environ.get("MULTIPROCESSING_VIA_IO"): + # If MULTIPROCESSING_VIA_IO is set, _submit_via_io is used to + # workaround size constraints in pythons multiprocessing + # implementation. Also see https://github.com/python/cpython/pull/10305/files + # This should be fixed in python 3.8 + submit_fn = self._submit_via_io + else: + submit_fn = super().submit # type: ignore[assignment] + + # Depending on the start_method and output_pickle_path, wrapper functions may need to be + # executed in the new process context, before the actual code is ran. + # These wrapper functions consume their arguments from *args, **kwargs and assume + # that the next argument will be another function that is then called. + # The call_stack holds all of these wrapper functions and their arguments in the correct order. + # For example, call_stack = [wrapper_fn_1, wrapper_fn_1_arg_1, wrapper_fn_2, actual_fn, actual_fn_arg_1] + # where wrapper_fn_1 is called, which eventually calls wrapper_fn_2, which eventually calls actual_fn. + call_stack = [] + + if self._mp_context.get_start_method() != "fork": + # If a start_method other than the default "fork" is used, logging needs to be re-setup, + # because the programming context is not inherited in those cases. + multiprocessing_logging_setup_fn = _get_multiprocessing_logging_setup_fn() + call_stack.extend( + [ + MultiprocessingExecutor._setup_logging_and_execute, + multiprocessing_logging_setup_fn, + ] + ) + + if output_pickle_path is not None: + call_stack.extend( + [ + MultiprocessingExecutor._execute_and_persist_function, + output_pickle_path, + ] + ) + + fut = submit_fn(*call_stack, __fn, *args, **kwargs) + + enrich_future_with_uncaught_warning(fut) + return fut + + def _submit_via_io( + self, + __fn: Callable[_P, _T], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> "Future[_T]": + opt_tmp_dir = os.environ.get("MULTIPROCESSING_VIA_IO_TMP_DIR") + if opt_tmp_dir is not None: + dirpath = tempfile.mkdtemp(dir=opt_tmp_dir) + else: + dirpath = tempfile.mkdtemp() + + output_pickle_path = Path(dirpath) / "jobdescription.pickle" + + with open(output_pickle_path, "wb") as file: + pickling.dump((__fn, args, kwargs), file) + + future = super().submit( + MultiprocessingExecutor._execute_via_io, output_pickle_path + ) + + future.add_done_callback( + partial(MultiprocessingExecutor._remove_tmp_file, dirpath) + ) + + return future + + @staticmethod + def _remove_tmp_file(path: os.PathLike, _future: Future) -> None: + rmtree(path) + + @staticmethod + def _setup_logging_and_execute( + multiprocessing_logging_setup_fn: Callable[[], None], + fn: Callable[_P, "Future[_T]"], + *args: Any, + **kwargs: Any, + ) -> "Future[_T]": + multiprocessing_logging_setup_fn() + return fn(*args, **kwargs) + + @staticmethod + def _execute_via_io(serialized_function_info_path: os.PathLike) -> Any: + with open(serialized_function_info_path, "rb") as file: + (fn, args, kwargs) = pickling.load(file) + return fn(*args, **kwargs) + + @staticmethod + def _execute_and_persist_function( + output_pickle_path: os.PathLike, + fn: Callable[_P, _T], + *args: _P.args, 
+ **kwargs: _P.kwargs, + ) -> _T: + try: + result = fn(*args, **kwargs) + except Exception as exc: + logging.warning(f"Job computation failed with:\n{exc.__repr__()}") + raise exc + else: + # Only pickle the result in the success case, since the output + # is used as a checkpoint. + # Note that this behavior differs a bit from the cluster executor + # which will always serialize the output (even exceptions) to + # disk. However, the output will have a .preliminary prefix at first + # which is only removed in the success case so that a checkpoint at + # the desired target only exists if the job was successful. + with open(output_pickle_path, "wb") as file: + pickling.dump((True, result), file) + return result + + def map_unordered(self, fn: Callable[_P, _T], args: Any) -> Iterator[_T]: + futs: List["Future[_T]"] = self.map_to_futures(fn, args) + # Return a separate generator to avoid that map_unordered + # is executed lazily (otherwise, jobs would be submitted + # lazily, as well). + def result_generator() -> Iterator: + for fut in futures.as_completed(futs): + yield fut.result() + + return result_generator() + + def map_to_futures( + self, + fn: Callable[[_S], _T], + args: Iterable[_S], # TODO change: allow more than one arg per call + output_pickle_path_getter: Optional[Callable[[_S], os.PathLike]] = None, + ) -> List["Future[_T]"]: + if output_pickle_path_getter is not None: + futs = [ + self.submit( # type: ignore[call-arg] + fn, + arg, + __cfut_options={ + "output_pickle_path": output_pickle_path_getter(arg) + }, + ) + for arg in args + ] + else: + futs = [self.submit(fn, arg) for arg in args] + + return futs + + def forward_log(self, fut: "Future[_T]") -> _T: + """ + Similar to the cluster executor, this method Takes a future from which the log file is forwarded to the active + process. This method blocks as long as the future is not done. + """ + + # Since the default behavior of process pool executors is to show the log in the main process + # we don't need to do anything except for blocking until the future is done. + return fut.result() diff --git a/cluster_tools/cluster_tools/executors/pickle.py b/cluster_tools/cluster_tools/executors/pickle.py new file mode 100644 index 000000000..c6a62780f --- /dev/null +++ b/cluster_tools/cluster_tools/executors/pickle.py @@ -0,0 +1,37 @@ +from concurrent.futures import Future +from typing import Any, Callable, TypeVar + +from cluster_tools._utils import pickling +from cluster_tools.executors.multiprocessing import MultiprocessingExecutor + +_T = TypeVar("_T") + + +def _pickle_identity(obj: _T) -> _T: + return pickling.loads(pickling.dumps(obj)) + + +def _pickle_identity_executor(fn: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: + result = fn(*args, **kwargs) + return _pickle_identity(result) + + +class PickleExecutor(MultiprocessingExecutor): + """ + The same as MultiprocessingExecutor, but always pickles input and output of the jobs. + When using this executor for automated tests, it is ensured that using cluster executors in production + won't provoke pickling-related problems. 
+ """ + + def submit( # type: ignore[override] + self, + fn: Callable[..., _T], + *args: Any, + **kwargs: Any, + ) -> "Future[_T]": + (fn_pickled, args_pickled, kwargs_pickled) = _pickle_identity( + (fn, args, kwargs) + ) + return super().submit( + _pickle_identity_executor, fn_pickled, *args_pickled, **kwargs_pickled + ) diff --git a/cluster_tools/cluster_tools/executors/sequential.py b/cluster_tools/cluster_tools/executors/sequential.py new file mode 100644 index 000000000..6050ae5d7 --- /dev/null +++ b/cluster_tools/cluster_tools/executors/sequential.py @@ -0,0 +1,29 @@ +from multiprocessing.context import BaseContext +from typing import Any, Callable, Optional, Tuple + +from cluster_tools.executors.multiprocessing import MultiprocessingExecutor + + +class SequentialExecutor(MultiprocessingExecutor): + """ + The same as MultiprocessingExecutor, but always uses only one core. In essence, + this is a sequential executor approach, but it still makes use of the standard pool approach. + That way, switching between different executors should always work without any problems. + """ + + def __init__( + self, + *, + start_method: Optional[str] = None, + mp_context: Optional[BaseContext] = None, + initializer: Optional[Callable] = None, + initargs: Tuple = (), + **__kwargs: Any, + ) -> None: + super().__init__( + max_workers=1, + start_method=start_method, + mp_context=mp_context, + initializer=initializer, + initargs=initargs, + ) diff --git a/cluster_tools/cluster_tools/pickling.py b/cluster_tools/cluster_tools/pickling.py deleted file mode 100644 index 7a236ac3d..000000000 --- a/cluster_tools/cluster_tools/pickling.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -import pickle -import sys - -from .util import warn_after - -WARNING_TIMEOUT = 10 * 60 # seconds - - -def file_path_to_absolute_module(file_path): - """ - Given a file path, return an import path. - :param file_path: A file path. 
- :return: - """ - assert os.path.exists(file_path) - file_loc, _ = os.path.splitext(file_path) - directory, module = os.path.split(file_loc) - module_path = [module] - while True: - if os.path.exists(os.path.join(directory, "__init__.py")): - directory, package = os.path.split(directory) - module_path.append(package) - else: - break - path = ".".join(module_path[::-1]) - return path - - -def get_suitable_pickle_protocol(): - # Protocol 4 allows to serialize objects larger than 4 GiB, but is only supported - # beginning from Python 3.4 - protocol = 4 if sys.version_info[0] >= 3 and sys.version_info[1] >= 4 else 3 - return protocol - - -@warn_after("pickle.dumps", WARNING_TIMEOUT) -def dumps(*args, **kwargs): - return pickle.dumps(*args, protocol=get_suitable_pickle_protocol(), **kwargs) - - -@warn_after("pickle.dump", WARNING_TIMEOUT) -def dump(*args, **kwargs): - return pickle.dump(*args, protocol=get_suitable_pickle_protocol(), **kwargs) - - -@warn_after("pickle.loads", WARNING_TIMEOUT) -def loads(*args, **kwargs): - assert ( - "custom_main_path" not in kwargs - ), "loads does not implement support for the argument custom_main_path" - return pickle.loads(*args, **kwargs) - - -class RenameUnpickler(pickle.Unpickler): - def find_class(self, module, name): - renamed_module = module - if module == "__main__" and self.custom_main_path is not None: - renamed_module = self.custom_main_path - - return super(RenameUnpickler, self).find_class(renamed_module, name) - - -@warn_after("pickle.load", WARNING_TIMEOUT) -def load(f, custom_main_path=None): - unpickler = RenameUnpickler(f) - unpickler.custom_main_path = ( # pylint: disable=attribute-defined-outside-init - custom_main_path - ) - return unpickler.load() diff --git a/cluster_tools/cluster_tools/remote.py b/cluster_tools/cluster_tools/remote.py index ac2912196..5d310301d 100644 --- a/cluster_tools/cluster_tools/remote.py +++ b/cluster_tools/cluster_tools/remote.py @@ -3,30 +3,34 @@ import os import sys import traceback +from typing import Any, Dict, Optional, Type +from cluster_tools._utils import pickling +from cluster_tools._utils.string import with_preliminary_postfix +from cluster_tools.schedulers.cluster_executor import ClusterExecutor from cluster_tools.schedulers.kube import KubernetesExecutor from cluster_tools.schedulers.pbs import PBSExecutor from cluster_tools.schedulers.slurm import SlurmExecutor -from cluster_tools.util import with_preliminary_postfix -from . import pickling - -def get_executor_class(executor_key): +def get_executor_class(executor_key: str) -> Type[ClusterExecutor]: return { "slurm": SlurmExecutor, "pbs": PBSExecutor, "kubernetes": KubernetesExecutor, - }.get(executor_key) + }[executor_key] -def format_remote_exc(): +def format_remote_exc() -> str: typ, value, tb = sys.exc_info() - tb = tb.tb_next # Remove root call to worker(). + if tb is not None: + tb = tb.tb_next # Remove root call to worker(). 
return "".join(traceback.format_exception(typ, value, tb)) -def get_custom_main_path(workerid, executor): +def get_custom_main_path( + workerid: str, executor: Type[ClusterExecutor] +) -> Optional[str]: custom_main_path = None main_meta_path = executor.get_main_meta_path(cfut_dir, workerid) if os.path.exists(main_meta_path): @@ -35,12 +39,18 @@ def get_custom_main_path(workerid, executor): return custom_main_path -def worker(executor, workerid, job_array_index, job_array_index_offset, cfut_dir): +def worker( + executor: Type[ClusterExecutor], + workerid: str, + job_array_index: Optional[int], + job_array_index_offset: str, + cfut_dir: str, +) -> None: """Called to execute a job on a remote host.""" if job_array_index is not None: workerid_with_idx = ( - worker_id + "_" + str(int(job_array_index_offset) + int(job_array_index)) + worker_id + "_" + str(int(job_array_index_offset) + job_array_index) ) else: workerid_with_idx = worker_id @@ -97,7 +107,9 @@ def worker(executor, workerid, job_array_index, job_array_index_offset, cfut_dir logging.debug("Pickle file renamed to {}.".format(destfile)) -def setup_logging(meta_data, executor, cfut_dir): +def setup_logging( + meta_data: Dict[str, Any], executor: Type[ClusterExecutor], cfut_dir: str +) -> None: if "logging_setup_fn" in meta_data: logging.debug("Using supplied logging_setup_fn to setup logging.") job_id_string = executor.get_job_id_string() diff --git a/cluster_tools/cluster_tools/schedulers/cluster_executor.py b/cluster_tools/cluster_tools/schedulers/cluster_executor.py index a60384246..be5f6f535 100644 --- a/cluster_tools/cluster_tools/schedulers/cluster_executor.py +++ b/cluster_tools/cluster_tools/schedulers/cluster_executor.py @@ -6,23 +6,42 @@ import time from abc import abstractmethod from concurrent import futures +from concurrent.futures import Future from functools import partial -from typing import List, Optional, Tuple, Type +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) -from typing_extensions import Literal +from typing_extensions import Literal, ParamSpec -from cluster_tools import pickling -from cluster_tools.pickling import file_path_to_absolute_module -from cluster_tools.tailf import Tail -from cluster_tools.util import ( - FileWaitThread, - enrich_future_with_uncaught_warning, +from cluster_tools._utils import pickling +from cluster_tools._utils.file_wait_thread import FileWaitThread +from cluster_tools._utils.reflection import ( + file_path_to_absolute_module, get_function_name, - random_string, - with_preliminary_postfix, ) +from cluster_tools._utils.string import random_string, with_preliminary_postfix +from cluster_tools._utils.tailf import Tail +from cluster_tools._utils.warning import enrich_future_with_uncaught_warning +from cluster_tools.executors.multiprocessing import CFutDict + +NOT_YET_SUBMITTED_STATE_TYPE = Literal["NOT_YET_SUBMITTED"] +NOT_YET_SUBMITTED_STATE: NOT_YET_SUBMITTED_STATE_TYPE = "NOT_YET_SUBMITTED" -NOT_YET_SUBMITTED_STATE = "NOT_YET_SUBMITTED" +_T = TypeVar("_T") +_P = ParamSpec("_P") +_S = TypeVar("_S") def join_messages(strings: List[str]) -> str: @@ -30,16 +49,18 @@ def join_messages(strings: List[str]) -> str: class RemoteException(Exception): - def __init__(self, error, job_id): # pylint: disable=super-init-not-called + def __init__( + self, error: str, job_id: str + ): # pylint: disable=super-init-not-called self.error = error self.job_id = job_id - def __str__(self): + def __str__(self) -> str: 
return self.error.strip() + f" (job_id={self.job_id})" class RemoteOutOfMemoryException(RemoteException): - def __str__(self): + def __str__(self) -> str: return str(self.job_id) + "\n" + self.error.strip() @@ -48,13 +69,13 @@ class ClusterExecutor(futures.Executor): def __init__( self, - debug=False, - keep_logs=True, - cfut_dir=None, - job_resources=None, - job_name=None, - additional_setup_lines=None, - **kwargs, + debug: bool = False, + keep_logs: bool = True, + cfut_dir: Optional[str] = None, + job_resources: Optional[Dict[str, Any]] = None, + job_name: Optional[str] = None, + additional_setup_lines: Optional[List[str]] = None, + **kwargs: Any, ): """ `kwargs` can be the following optional parameters: @@ -74,7 +95,7 @@ def __init__( self.cfut_dir = ( cfut_dir if cfut_dir is not None else os.getenv("CFUT_DIR", ".cfut") ) - self.files_to_clean_up = [] + self.files_to_clean_up: List[str] = [] logging.info( f"Instantiating ClusterExecutor. Log files are stored in {self.cfut_dir}" @@ -83,8 +104,10 @@ def __init__( # `jobs` maps from job id to (future, workerid, outfile_name, should_keep_output) # In case, job arrays are used: job id and workerid are in the format of # `job_id-job_index` and `workerid-job_index`. - self.jobs = {} - self.job_outfiles = {} + self.jobs: Dict[ + str, + Union[NOT_YET_SUBMITTED_STATE_TYPE, Tuple[Future, str, str, bool]], + ] = {} self.jobs_lock = threading.Lock() self.jobs_empty_cond = threading.Condition(self.jobs_lock) self.keep_logs = keep_logs @@ -113,10 +136,12 @@ def __init__( @classmethod @abstractmethod - def executor_key(cls): + def executor_key(cls) -> str: pass - def handle_kill(self, existing_sigint_handler, signum, frame): + def handle_kill( + self, existing_sigint_handler: Any, signum: Optional[int], frame: Any + ) -> None: if self.is_shutting_down: return @@ -133,17 +158,22 @@ def handle_kill(self, existing_sigint_handler, signum, frame): existing_sigint_handler(signum, frame) @abstractmethod - def inner_handle_kill(self, _signum, _frame): + def inner_handle_kill(self, _signum: Any, _frame: Any) -> None: pass @abstractmethod def check_job_state( - self, job_id_with_index + self, job_id_with_index: str ) -> Literal["failed", "ignore", "completed"]: pass + @staticmethod + @abstractmethod + def get_current_job_id() -> str: + pass + def investigate_failed_job( - self, job_id_with_index # pylint: disable=unused-argument + self, job_id_with_index: str # pylint: disable=unused-argument ) -> Optional[Tuple[str, Type[RemoteException]]]: """ When a job fails, this method is called to investigate why. If a tuple is returned, @@ -154,7 +184,12 @@ def investigate_failed_job( """ return None - def _start(self, workerid, job_count=None, job_name=None): + def _start( + self, + workerid: str, + job_count: Optional[int] = None, + job_name: Optional[str] = None, + ) -> Tuple[List["Future[str]"], List[Tuple[int, int]]]: """Start job(s) with the given worker ID and return IDs identifying the new job(s). The job should run ``python -m cfut.remote . @@ -183,10 +218,10 @@ def inner_submit( job_name: Optional[str] = None, additional_setup_lines: Optional[List[str]] = None, job_count: Optional[int] = None, - ) -> Tuple[List["futures.Future[str]"], List[Tuple[int, int]]]: + ) -> Tuple[List["Future[str]"], List[Tuple[int, int]]]: pass - def _cleanup(self, jobid): + def _cleanup(self, jobid: str) -> None: """Given a job ID as returned by _start, perform any necessary cleanup after the job has finished. 
""" @@ -198,44 +233,53 @@ def _cleanup(self, jobid): @staticmethod @abstractmethod - def format_log_file_name(job_id_with_index, suffix=".stdout"): + def format_log_file_name(job_id_with_index: str, suffix: str = ".stdout") -> str: pass @classmethod - def format_log_file_path(cls, cfut_dir, job_id_with_index, suffix=".stdout"): + def format_log_file_path( + cls, cfut_dir: str, job_id_with_index: str, suffix: str = ".stdout" + ) -> str: return os.path.join( cfut_dir, cls.format_log_file_name(job_id_with_index, suffix) ) @classmethod @abstractmethod - def get_job_id_string(self): + def get_job_id_string(self) -> str: + pass + + @staticmethod + @abstractmethod + def get_job_array_index() -> Optional[int]: pass @staticmethod - def get_temp_file_path(cfut_dir, file_name): + def get_temp_file_path(cfut_dir: str, file_name: str) -> str: return os.path.join(cfut_dir, file_name) @staticmethod - def format_infile_name(cfut_dir, job_id): + def format_infile_name(cfut_dir: str, job_id: str) -> str: return os.path.join(cfut_dir, "cfut.in.%s.pickle" % job_id) @staticmethod - def format_outfile_name(cfut_dir, job_id): + def format_outfile_name(cfut_dir: str, job_id: str) -> str: return os.path.join(cfut_dir, "cfut.out.%s.pickle" % job_id) - def get_python_executable(self): + def get_python_executable(self) -> str: return sys.executable - def _completion(self, jobid, failed_early): + def _completion(self, jobid: str, failed_early: bool) -> None: """Called whenever a job finishes.""" with self.jobs_lock: job_info = self.jobs.pop(jobid) + assert job_info != NOT_YET_SUBMITTED_STATE + if len(job_info) == 4: fut, workerid, outfile_name, should_keep_output = job_info else: # Backwards compatibility - fut, workerid = job_info + fut, workerid = job_info # type: ignore[misc] should_keep_output = False outfile_name = self.format_outfile_name(self.cfut_dir, workerid) @@ -300,18 +344,23 @@ def _completion(self, jobid, failed_early): self._cleanup(jobid) - def ensure_not_shutdown(self): + def ensure_not_shutdown(self) -> None: if self.was_requested_to_shutdown: raise RuntimeError( "submit() was invoked on a ClusterExecutor instance even though shutdown() was executed for that instance." ) - def create_enriched_future(self): - fut = futures.Future() + def create_enriched_future(self) -> Future: + fut: Future = Future() enrich_future_with_uncaught_warning(fut) return fut - def submit(self, fun, *args, **kwargs): + def submit( # type: ignore[override] + self, + __fn: Callable[_P, _T], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> "Future[_T]": """ Submit a job to the pool. kwargs may contain __cfut_options which currently should look like: @@ -324,19 +373,21 @@ def submit(self, fun, *args, **kwargs): fut = self.create_enriched_future() workerid = random_string() - should_keep_output = False if "__cfut_options" in kwargs: should_keep_output = True - output_pickle_path = kwargs["__cfut_options"]["output_pickle_path"] + output_pickle_path = cast(CFutDict, kwargs["__cfut_options"])[ + "output_pickle_path" + ] del kwargs["__cfut_options"] else: + should_keep_output = False output_pickle_path = self.format_outfile_name(self.cfut_dir, workerid) self.ensure_not_shutdown() # Start the job. 
serialized_function_info = pickling.dumps( - (fun, args, kwargs, self.meta_data, output_pickle_path) + (__fn, args, kwargs, self.meta_data, output_pickle_path) ) with open(self.format_infile_name(self.cfut_dir, workerid), "wb") as f: f.write(serialized_function_info) @@ -350,7 +401,7 @@ def submit(self, fun, *args, **kwargs): ) os.unlink(preliminary_output_pickle_path) - job_name = get_function_name(fun) + job_name = get_function_name(__fn) jobids_futures, _ = self._start(workerid, job_name=job_name) # Only a single job was submitted jobid = jobids_futures[0].result() @@ -364,34 +415,39 @@ def submit(self, fun, *args, **kwargs): with self.jobs_lock: self.jobs[jobid] = (fut, workerid, output_pickle_path, should_keep_output) - fut.cluster_jobid = jobid + fut.cluster_jobid = jobid # type: ignore[attr-defined] return fut @classmethod - def get_workerid_with_index(cls, workerid, index): - return workerid + "_" + str(index) + def get_workerid_with_index(cls, workerid: str, index: Union[int, str]) -> str: + return f"{workerid}_{index}" @classmethod - def get_jobid_with_index(cls, jobid, index): - return str(jobid) + "_" + str(index) + def get_jobid_with_index(cls, jobid: Union[str, int], index: int) -> str: + return f"{jobid}_{index}" - def get_function_pickle_path(self, workerid): + def get_function_pickle_path(self, workerid: str) -> str: return self.format_infile_name( self.cfut_dir, self.get_workerid_with_index(workerid, "function") ) @staticmethod - def get_main_meta_path(cfut_dir, workerid): + def get_main_meta_path(cfut_dir: str, workerid: str) -> str: return os.path.join(cfut_dir, f"cfut.main_path.{workerid}.txt") - def store_main_path_to_meta_file(self, workerid): + def store_main_path_to_meta_file(self, workerid: str) -> None: with open(self.get_main_meta_path(self.cfut_dir, workerid), "w") as file: file.write(file_path_to_absolute_module(sys.argv[0])) - def map_to_futures(self, fun, allArgs, output_pickle_path_getter=None): + def map_to_futures( + self, + fn: Callable[[_S], _T], + args: Iterable[_S], # TODO change: allow more than one arg per call + output_pickle_path_getter: Optional[Callable[[_S], os.PathLike]] = None, + ) -> List["Future[_T]"]: self.ensure_not_shutdown() - allArgs = list(allArgs) - if len(allArgs) == 0: + args = list(args) + if len(args) == 0: return [] should_keep_output = output_pickle_path_getter is not None @@ -402,10 +458,10 @@ def map_to_futures(self, fun, allArgs, output_pickle_path_getter=None): pickled_function_path = self.get_function_pickle_path(workerid) self.files_to_clean_up.append(pickled_function_path) with open(pickled_function_path, "wb") as file: - pickling.dump(fun, file) + pickling.dump(fn, file) self.store_main_path_to_meta_file(workerid) - for index, arg in enumerate(allArgs): + for index, arg in enumerate(args): fut = self.create_enriched_future() workerid_with_index = self.get_workerid_with_index(workerid, index) @@ -414,7 +470,7 @@ def map_to_futures(self, fun, allArgs, output_pickle_path_getter=None): self.cfut_dir, workerid_with_index ) else: - output_pickle_path = output_pickle_path_getter(arg) + output_pickle_path = str(output_pickle_path_getter(arg)) preliminary_output_pickle_path = with_preliminary_postfix( output_pickle_path @@ -445,8 +501,8 @@ def map_to_futures(self, fun, allArgs, output_pickle_path_getter=None): # not even submitted yet. 
self.jobs[workerid_with_index] = NOT_YET_SUBMITTED_STATE - job_count = len(allArgs) - job_name = get_function_name(fun) + job_count = len(args) + job_name = get_function_name(fn) jobids_futures, job_index_ranges = self._start(workerid, job_count, job_name) number_of_batches = len(jobids_futures) @@ -468,13 +524,13 @@ def map_to_futures(self, fun, allArgs, output_pickle_path_getter=None): def register_jobs( self, - futs_with_output_paths, - workerid, - should_keep_output, - job_index_offset, - batch_description, - jobid_future, - ): + futs_with_output_paths: List[Tuple[Future, str]], + workerid: str, + should_keep_output: bool, + job_index_offset: int, + batch_description: str, + jobid_future: "Future[str]", + ) -> None: jobid = jobid_future.result() if self.debug: @@ -492,8 +548,9 @@ def register_jobs( with_preliminary_postfix(output_path), jobid_with_index ) - fut.cluster_jobid = jobid - fut.cluster_jobindex = array_index + fut.cluster_jobid = jobid # type: ignore[attr-defined] + # fut.cluster_jobindex is only used for debugging: + fut.cluster_jobindex = array_index # type: ignore[attr-defined] job_index = job_index_offset + array_index workerid_with_index = self.get_workerid_with_index(workerid, job_index) @@ -507,8 +564,12 @@ def register_jobs( should_keep_output, ) - def shutdown(self, wait=True): + def shutdown(self, wait: bool = True, cancel_futures: bool = True) -> None: """Close the pool.""" + if not cancel_futures: + logging.warning( + "The provided cancel_futures argument is ignored by ClusterExecutor." + ) self.was_requested_to_shutdown = True if wait: with self.jobs_lock: @@ -525,22 +586,29 @@ def shutdown(self, wait=True): pass self.files_to_clean_up = [] - def map(self, func, args, timeout=None, chunksize=None): + # TODO: args should be *iterables, this would be a breaking change! + def map( # type: ignore[override] + self, + fn: Callable[_P, _T], + args: Iterable[Any], + timeout: Optional[float] = None, + chunksize: Optional[int] = None, + ) -> Iterator[_T]: if chunksize is not None: logging.warning( - "The provided chunksize parameter is ignored by ClusterExecutor." + "The provided chunksize argument is ignored by ClusterExecutor." ) start_time = time.time() - futs = self.map_to_futures(func, args) + futs = self.map_to_futures(fn, args) # Return a separate generator as an iterator to avoid that the # map() method itself becomes a generator (due to the usage of yield). # If map() was a generator, the submit() calls would be invoked # lazily which can lead to a shutdown of the executor before # the submit calls are performed. - def result_generator(): + def result_generator() -> Iterator[_T]: for fut in futs: passed_time = time.time() - start_time remaining_timeout = None if timeout is None else timeout - passed_time @@ -548,27 +616,27 @@ def result_generator(): return result_generator() - def map_unordered(self, func, args): - futs = self.map_to_futures(func, args) + def map_unordered(self, fn: Callable[_P, _T], args: Any) -> Iterator[_T]: + futs = self.map_to_futures(fn, args) # Return a separate generator to avoid that map_unordered # is executed lazily. - def result_generator(): + def result_generator() -> Iterator[_T]: for fut in futures.as_completed(futs): yield fut.result() return result_generator() - def forward_log(self, fut): + def forward_log(self, fut: "Future[_T]") -> _T: """ Takes a future from which the log file is forwarded to the active process. This method blocks as long as the future is not done. 
""" - log_path = self.format_log_file_path(self.cfut_dir, fut.cluster_jobid) + log_path = self.format_log_file_path(self.cfut_dir, fut.cluster_jobid) # type: ignore[attr-defined] # Don't use a logger instance here, since the child process # probably already used a logger. - log_callback = lambda s: sys.stdout.write(f"(jid={fut.cluster_jobid}) {s}") + log_callback = lambda s: sys.stdout.write(f"(jid={fut.cluster_jobid}) {s}") # type: ignore[attr-defined] tailer = Tail(log_path, log_callback) fut.add_done_callback(lambda _: tailer.cancel()) @@ -582,5 +650,5 @@ def forward_log(self, fut): return fut.result() @abstractmethod - def get_pending_tasks(self): + def get_pending_tasks(self) -> Iterable[str]: pass diff --git a/cluster_tools/cluster_tools/schedulers/kube.py b/cluster_tools/cluster_tools/schedulers/kube.py index 7cb613c69..aa1d229dc 100644 --- a/cluster_tools/cluster_tools/schedulers/kube.py +++ b/cluster_tools/cluster_tools/schedulers/kube.py @@ -1,24 +1,24 @@ """Abstracts access to a Kubernetes cluster via its Python library.""" -import concurrent import os import re import sys +from concurrent.futures import Future from pathlib import Path -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from uuid import uuid4 import kubernetes import kubernetes.client.models as kubernetes_models from typing_extensions import Literal -from .cluster_executor import ClusterExecutor +from cluster_tools.schedulers.cluster_executor import ClusterExecutor -def volume_name_from_path(path: Path) -> str: +def _volume_name_from_path(path: Path) -> str: return f"{(hash(str(path)) & sys.maxsize):016x}" -def deduplicate_mounts(mounts: List[Path]) -> List[Path]: +def _deduplicate_mounts(mounts: List[Path]) -> List[Path]: output = [] unique_mounts = set(mounts) for mount in unique_mounts: @@ -28,15 +28,34 @@ def deduplicate_mounts(mounts: List[Path]) -> List[Path]: class KubernetesClient: - def __init__(self): + def __init__(self) -> None: kubernetes.config.load_kube_config() self.core = kubernetes.client.api.core_v1_api.CoreV1Api() self.batch = kubernetes.client.api.batch_v1_api.BatchV1Api() class KubernetesExecutor(ClusterExecutor): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + job_resources: Dict[str, Any] + + def __init__( + self, + debug: bool = False, + keep_logs: bool = True, + cfut_dir: Optional[str] = None, + job_resources: Optional[Dict[str, Any]] = None, + job_name: Optional[str] = None, + additional_setup_lines: Optional[List[str]] = None, + **kwargs: Any, + ): + super().__init__( + debug=debug, + keep_logs=keep_logs, + cfut_dir=cfut_dir, + job_resources=job_resources, + job_name=job_name, + additional_setup_lines=additional_setup_lines, + **kwargs, + ) if self.job_resources is None: self.job_resources = {} if "namespace" not in self.job_resources: @@ -53,7 +72,7 @@ def executor_key(cls) -> str: return "kubernetes" @staticmethod - def format_log_file_name(job_id_with_index: str, suffix=".stdout") -> str: + def format_log_file_name(job_id_with_index: str, suffix: str = ".stdout") -> str: return "kube.{}.log{}".format(str(job_id_with_index), suffix) @staticmethod @@ -70,18 +89,20 @@ def get_job_array_index() -> Optional[int]: return None @staticmethod - def get_current_job_id() -> Optional[str]: - return os.environ.get("JOB_ID", None) + def get_current_job_id() -> str: + r = os.environ.get("JOB_ID") + assert r is not None + return r @classmethod - def get_job_id_string(cls) -> Optional[str]: + def get_job_id_string(cls) -> str: 
job_id = cls.get_current_job_id() job_index = cls.get_job_array_index() if job_index is None: return job_id return cls.get_jobid_with_index(job_id, job_index) - def inner_handle_kill(self, *args, **kwargs): + def inner_handle_kill(self, *args: Any, **kwargs: Any) -> None: job_ids = ",".join(str(job_id) for job_id in self.jobs.keys()) print( @@ -90,7 +111,7 @@ def inner_handle_kill(self, *args, **kwargs): ) ) - def ensure_kubernetes_namespace(self): + def ensure_kubernetes_namespace(self) -> None: kubernetes_client = KubernetesClient() try: kubernetes_client.core.read_namespace(self.job_resources["namespace"]) @@ -107,7 +128,7 @@ def ensure_kubernetes_namespace(self): ) ) - def get_python_executable(self): + def get_python_executable(self) -> str: return self.job_resources.get("python_executable", "python") def inner_submit( @@ -116,14 +137,14 @@ def inner_submit( job_name: Optional[str] = None, additional_setup_lines: Optional[List[str]] = None, job_count: Optional[int] = None, - ) -> Tuple[List["concurrent.futures.Future[str]"], List[Tuple[int, int]]]: + ) -> Tuple[List["Future[str]"], List[Tuple[int, int]]]: """Starts a Kubernetes pod that runs the specified shell command line.""" kubernetes_client = KubernetesClient() self.ensure_kubernetes_namespace() job_id = str(uuid4()) - job_id_future: "concurrent.futures.Future[str]" = concurrent.futures.Future() + job_id_future: "Future[str]" = Future() job_id_future.set_result(job_id) job_id_futures = [job_id_future] @@ -146,7 +167,7 @@ def inner_submit( if is_array_job else self.format_log_file_path(self.cfut_dir, job_id) ) - mounts = deduplicate_mounts( + mounts = _deduplicate_mounts( [Path(mount) for mount in self.job_resources["mounts"]] + [Path.cwd(), Path(self.cfut_dir).absolute()] ) @@ -207,7 +228,7 @@ def inner_submit( ), volume_mounts=[ kubernetes_models.V1VolumeMount( - name=volume_name_from_path(mount), + name=_volume_name_from_path(mount), mount_path=str(mount), ) for mount in mounts @@ -218,7 +239,7 @@ def inner_submit( restart_policy="Never", volumes=[ kubernetes_models.V1Volume( - name=volume_name_from_path(mount), + name=_volume_name_from_path(mount), host_path=kubernetes_models.V1HostPathVolumeSource( path=str(mount) ), diff --git a/cluster_tools/cluster_tools/schedulers/pbs.py b/cluster_tools/cluster_tools/schedulers/pbs.py index bdaaab392..ace527868 100644 --- a/cluster_tools/cluster_tools/schedulers/pbs.py +++ b/cluster_tools/cluster_tools/schedulers/pbs.py @@ -3,14 +3,14 @@ import logging import os import re -from concurrent import futures -from typing import Dict, List, Optional, Tuple, Union +from concurrent.futures import Future +from typing import Any, Dict, List, Optional, Tuple, Union from typing_extensions import Literal -from cluster_tools.util import call, chcall, random_string - -from .cluster_executor import ClusterExecutor +from cluster_tools._utils.call import call, chcall +from cluster_tools._utils.string import random_string +from cluster_tools.schedulers.cluster_executor import ClusterExecutor # qstat vs. 
checkjob PBS_STATES: Dict[str, List[str]] = { @@ -38,22 +38,27 @@ def executor_key(cls) -> str: return "pbs" @staticmethod - def get_job_array_index(): - return os.environ.get("PBS_ARRAYID", None) + def get_job_array_index() -> Optional[int]: + try: + return int(os.environ["PBS_ARRAYID"]) + except KeyError: + return None @staticmethod - def get_current_job_id(): - return os.environ.get("PBS_JOBID") + def get_current_job_id() -> str: + r = os.environ.get("PBS_JOBID") + assert r is not None + return r @staticmethod - def format_log_file_name(job_id_with_index, suffix=".stdout"): - return "pbs.{}.log{}".format(str(job_id_with_index), suffix) + def format_log_file_name(job_id_with_index: str, suffix: str = ".stdout") -> str: + return f"pbs.{job_id_with_index}.log{suffix}" @classmethod - def get_job_id_string(cls): + def get_job_id_string(cls) -> str: return cls.get_current_job_id() - def inner_handle_kill(self, *args, **kwargs): + def inner_handle_kill(self, *args: Any, **kwargs: Any) -> None: scheduled_job_ids: List[Union[int, str]] = list(self.jobs.keys()) if len(scheduled_job_ids): @@ -78,7 +83,7 @@ def inner_handle_kill(self, *args, **kwargs): f"Couldn't automatically cancel all PBS jobs. Reason: {stderr}" ) - def submit_text(self, job): + def submit_text(self, job: str) -> str: """Submits a PBS job represented as a job file string. Returns the job ID. """ @@ -89,13 +94,13 @@ def submit_text(self, job): with open(filename, "w") as f: f.write(job) jobid_desc, _ = chcall("qsub -V {}".format(filename)) - match = re.search("^[0-9]+", jobid_desc.decode("utf-8")) + match = re.search("^[0-9]+", jobid_desc) assert match is not None jobid = match.group(0) print("jobid", jobid) # os.unlink(filename) - return int(jobid) + return str(int(jobid)) # int() ensures coherent parsing def inner_submit( self, @@ -103,7 +108,7 @@ def inner_submit( job_name: Optional[str] = None, additional_setup_lines: Optional[List[str]] = None, job_count: Optional[int] = None, - ) -> Tuple[List["futures.Future[str]"], List[Tuple[int, int]]]: + ) -> Tuple[List["Future[str]"], List[Tuple[int, int]]]: """Starts a PBS job that runs the specified shell command line.""" if additional_setup_lines is None: additional_setup_lines = [] @@ -148,13 +153,13 @@ def inner_submit( ] job_id = self.submit_text("\n".join(script_lines)) - job_id_future: "futures.Future[str]" = futures.Future() + job_id_future: "Future[str]" = Future() job_id_future.set_result(job_id) return [job_id_future], [(0, job_count or 1)] def check_job_state( - self, job_id_with_index + self, job_id_with_index: str ) -> Literal["failed", "ignore", "completed"]: if len(str(job_id_with_index).split("_")) >= 2: a, b = job_id_with_index.split("_") @@ -173,7 +178,7 @@ def check_job_state( return "ignore" else: - job_state_search = re.search("job_state = ([a-zA-Z_]*)", str(stdout)) + job_state_search = re.search("job_state = ([a-zA-Z_]*)", stdout) if job_state_search: job_state = job_state_search.group(1) @@ -199,6 +204,6 @@ def check_job_state( ) return "ignore" - def get_pending_tasks(self): + def get_pending_tasks(self) -> List: # Not implemented, yet. Currently, this is only used for performance optimization. return [] diff --git a/cluster_tools/cluster_tools/schedulers/slurm.py b/cluster_tools/cluster_tools/schedulers/slurm.py index d87042161..fa5ea7e0d 100644 --- a/cluster_tools/cluster_tools/schedulers/slurm.py +++ b/cluster_tools/cluster_tools/schedulers/slurm.py @@ -1,19 +1,31 @@ """Abstracts access to a Slurm cluster via its command-line tools. 
""" -import concurrent import logging import os import re import sys import threading +from concurrent.futures import Future from functools import lru_cache -from typing import List, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) from typing_extensions import Literal -from cluster_tools.util import call, chcall, random_string - -from .cluster_executor import ( +from cluster_tools._utils.call import call, chcall +from cluster_tools._utils.string import random_string +from cluster_tools.schedulers.cluster_executor import ( NOT_YET_SUBMITTED_STATE, ClusterExecutor, RemoteException, @@ -49,41 +61,69 @@ SLURM_QUEUE_CHECK_INTERVAL = 1 if "pytest" in sys.modules else 60 +T = TypeVar("T") -def noopDecorator(func): + +def noopDecorator(func: T) -> T: return func -cache_in_production = noopDecorator if "pytest" in sys.modules else lru_cache(maxsize=1) +cache_in_production = cast( + Callable[[T], T], noopDecorator if "pytest" in sys.modules else lru_cache(maxsize=1) +) class SlurmExecutor(ClusterExecutor): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.submit_threads = [] + def __init__( + self, + debug: bool = False, + keep_logs: bool = True, + cfut_dir: Optional[str] = None, + job_resources: Optional[Dict[str, Any]] = None, + job_name: Optional[str] = None, + additional_setup_lines: Optional[List[str]] = None, + **kwargs: Any, + ): + super().__init__( + debug=debug, + keep_logs=keep_logs, + cfut_dir=cfut_dir, + job_resources=job_resources, + job_name=job_name, + additional_setup_lines=additional_setup_lines, + **kwargs, + ) + self.submit_threads: List["_JobSubmitThread"] = [] @classmethod def executor_key(cls) -> str: return "slurm" @staticmethod - def get_job_array_index(): - return os.environ.get("SLURM_ARRAY_TASK_ID", None) + def get_job_array_index() -> Optional[int]: + try: + return int(os.environ["SLURM_ARRAY_TASK_ID"]) + except KeyError: + return None @staticmethod - def get_job_array_id(): - return os.environ.get("SLURM_ARRAY_JOB_ID", None) + def get_job_array_id() -> str: + r = os.environ.get("SLURM_ARRAY_JOB_ID") + assert r is not None + return r @staticmethod - def get_current_job_id(): - return os.environ.get("SLURM_JOB_ID") + def get_current_job_id() -> str: + r = os.environ.get("SLURM_JOB_ID") + assert r is not None + return r @staticmethod - def format_log_file_name(job_id_with_index, suffix=".stdout"): - return "slurmpy.{}.log{}".format(str(job_id_with_index), suffix) + def format_log_file_name(job_id_with_index: str, suffix: str = ".stdout") -> str: + return f"slurmpy.{job_id_with_index}.log{suffix}" @classmethod - def get_job_id_string(cls): + def get_job_id_string(cls) -> str: job_id = cls.get_current_job_id() job_array_id = cls.get_job_array_id() job_array_index = cls.get_job_array_index() @@ -97,7 +137,7 @@ def get_job_id_string(cls): @staticmethod @cache_in_production - def get_max_array_size(): + def get_max_array_size() -> int: max_array_size_env = os.environ.get("SLURM_MAX_ARRAY_SIZE", None) if max_array_size_env is not None: logging.debug( @@ -111,7 +151,7 @@ def get_max_array_size(): "scontrol show config | sed -n '/^MaxArraySize/s/.*= *//p'" ) if exit_code == 0: - max_array_size = int(stdout.decode("utf8")) + max_array_size = int(stdout) logging.debug(f"Slurm MaxArraySize is {max_array_size}.") else: logging.warning( @@ -121,7 +161,7 @@ def get_max_array_size(): @staticmethod @cache_in_production - def get_max_running_size(): + 
def get_max_running_size() -> int: max_running_size_env = os.environ.get("SLURM_MAX_RUNNING_SIZE", None) if max_running_size_env is not None: logging.debug( @@ -133,7 +173,7 @@ def get_max_running_size(): @staticmethod @cache_in_production - def get_max_submit_jobs(): + def get_max_submit_jobs() -> int: max_submit_jobs_env = os.environ.get("SLURM_MAX_SUBMIT_JOBS", None) if max_submit_jobs_env is not None: logging.debug( @@ -147,14 +187,14 @@ def get_max_submit_jobs(): "sacctmgr list -n user $USER withassoc format=maxsubmitjobsperuser" ) try: - max_submit_jobs = int(stdout_user.decode("utf8")) + max_submit_jobs = int(stdout_user) except ValueError: # If there is no limit per user check whether there is a general limit stdout_qos, stderr_qos, _ = call( "sacctmgr list -n qos normal format=maxsubmitjobsperuser" ) try: - max_submit_jobs = int(stdout_qos.decode("utf8")) + max_submit_jobs = int(stdout_qos) except ValueError: logging.warning( f"Slurm's MaxSubmitJobsPerUser couldn't be determined. Reason: {stderr_user}\n{stderr_qos}" @@ -164,7 +204,7 @@ def get_max_submit_jobs(): return max_submit_jobs @staticmethod - def get_number_of_submitted_jobs(state: Optional[str] = None): + def get_number_of_submitted_jobs(state: Optional[str] = None) -> int: number_of_submitted_jobs = 0 state_string = f"-t {state}" if state else "" # --array so that each job array element is displayed on a separate line and -h to hide the header @@ -174,7 +214,7 @@ def get_number_of_submitted_jobs(state: Optional[str] = None): job_state_string = f"with state {state} " if state else "" if exit_code == 0: - number_of_submitted_jobs = int(stdout.decode("utf8")) + number_of_submitted_jobs = int(stdout) logging.debug( f"Number of currently submitted jobs {job_state_string}is {number_of_submitted_jobs}." ) @@ -185,7 +225,7 @@ def get_number_of_submitted_jobs(state: Optional[str] = None): return number_of_submitted_jobs @classmethod - def submit_text(cls, job, cfut_dir): + def submit_text(cls, job: str, cfut_dir: str) -> str: """Submits a Slurm job represented as a job file string. Returns the job ID. 
""" @@ -201,9 +241,9 @@ def submit_text(cls, job, cfut_dir): if len(stderr) > 0: logging.warning(f"Submitting batch job emitted warnings: {stderr}") - return int(job_id) + return str(int(job_id)) # int() ensures coherent parsing - def inner_handle_kill(self, *args, **kwargs): + def inner_handle_kill(self, *args: Any, **kwargs: Any) -> None: for submit_thread in self.submit_threads: submit_thread.stop() @@ -226,7 +266,7 @@ def inner_handle_kill(self, *args, **kwargs): ) maybe_error_or_warning = ( - f"\nErrors and warnings (if all jobs were pending 'Invalid job id' errors are expected):\n{stderr.decode('utf8')}" + f"\nErrors and warnings (if all jobs were pending 'Invalid job id' errors are expected):\n{stderr}" if stderr else "" ) @@ -234,7 +274,7 @@ def inner_handle_kill(self, *args, **kwargs): f"Canceled slurm jobs {', '.join(unique_job_ids)}.{maybe_error_or_warning}" ) - def cleanup_submit_threads(self): + def cleanup_submit_threads(self) -> None: self.submit_threads = [ thread for thread in self.submit_threads if thread.is_alive() ] @@ -245,7 +285,7 @@ def inner_submit( job_name: Optional[str] = None, additional_setup_lines: Optional[List[str]] = None, job_count: Optional[int] = None, - ) -> Tuple[List["concurrent.futures.Future[str]"], List[Tuple[int, int]]]: + ) -> Tuple[List["Future[str]"], List[Tuple[int, int]]]: """Starts a Slurm job that runs the specified shell command line.""" if additional_setup_lines is None: additional_setup_lines = [] @@ -271,7 +311,7 @@ def inner_submit( batch_size = max(min(max_array_size, max_submit_jobs), 1) scripts = [] - job_id_futures: List["concurrent.futures.Future[str]"] = [] + job_id_futures: List["Future[str]"] = [] ranges = [] number_of_jobs = job_count if job_count is not None else 1 for job_index_start in range(0, number_of_jobs, batch_size): @@ -298,7 +338,7 @@ def inner_submit( ] ) - job_id_futures.append(concurrent.futures.Future()) + job_id_futures.append(Future()) scripts.append("\n".join(script_lines)) ranges.append((job_index_start, job_index_end + 1)) @@ -306,7 +346,7 @@ def inner_submit( self.cleanup_submit_threads() - submit_thread = JobSubmitThread( + submit_thread = _JobSubmitThread( scripts, job_sizes, job_id_futures, self.cfut_dir ) self.submit_threads.append(submit_thread) @@ -315,29 +355,23 @@ def inner_submit( return job_id_futures, ranges def check_job_state( - self, job_id_with_index + self, job_id_with_index: str ) -> Literal["failed", "ignore", "completed"]: job_states = [] # If the output file was not found, we determine the job status so that # we can recognize jobs which failed hard (in this case, they don't produce output files) - stdout, _, exit_code = call("scontrol show job {}".format(job_id_with_index)) - stdout = stdout.decode("utf8") + stdout, _, exit_code = call(f"scontrol show job {job_id_with_index}") if exit_code == 0: - job_state_search = re.search("JobState=([a-zA-Z_]*)", str(stdout)) + job_state_search = re.search("JobState=([a-zA-Z_]*)", stdout) if job_state_search: job_states = [job_state_search.group(1)] else: - logging.error( - "Could not extract slurm job state? {}".format(stdout[0:10]) - ) + logging.error(f"Could not extract slurm job state? 
{stdout[0:10]}") else: - stdout, _, exit_code = call( - "sacct -j {} -o State -P".format(job_id_with_index) - ) - stdout = stdout.decode("utf8") + stdout, _, exit_code = call(f"sacct -j {job_id_with_index} -o State -P") if exit_code == 0: job_states = stdout.split("\n")[1:] @@ -348,7 +382,7 @@ def check_job_state( ) return "ignore" - def matches_states(slurm_states): + def matches_states(slurm_states: List[str]) -> bool: return len(list(set(job_states) & set(slurm_states))) > 0 if matches_states(SLURM_STATES["Failure"]): @@ -373,25 +407,26 @@ def matches_states(slurm_states): return "ignore" def investigate_failed_job( - self, job_id_with_index + self, job_id_with_index: str ) -> Optional[Tuple[str, Type[RemoteException]]]: # We call `seff job_id` which should return some output including a line, # such as: "Memory Efficiency: 25019.18% of 1.00 GB" - stdout, _, exit_code = call("seff {}".format(job_id_with_index)) + stdout, _, exit_code = call(f"seff {job_id_with_index}") if exit_code != 0: return None # Parse stdout into a key-value object properties = {} - stdout = stdout.decode("utf8") for line in stdout.split("\n"): if ":" not in line: continue key, value = line.split(":", 1) properties[key.strip()] = value.strip() - def investigate_memory_consumption(): + def investigate_memory_consumption() -> Optional[ + Tuple[str, Type[RemoteOutOfMemoryException]] + ]: if not properties.get("Memory Efficiency", None): return None @@ -416,7 +451,9 @@ def investigate_memory_consumption(): reason = f"The job was probably terminated because it consumed too much memory ({efficiency_note})." return (reason, RemoteOutOfMemoryException) - def investigate_exit_code(): + def investigate_exit_code() -> Optional[ + Tuple[str, Type[RemoteOutOfMemoryException]] + ]: if not properties.get("State", None): return None if "exit code 137" not in properties["State"]: @@ -436,12 +473,11 @@ def investigate_exit_code(): return investigate_exit_code() - def get_pending_tasks(self): + def get_pending_tasks(self) -> Iterable[str]: try: # Get the job ids (%i) of the active user (-u) which are pending (-t) and format # them one-per-line (-r) while excluding the header (-h). 
stdout, _ = chcall("squeue -u $(whoami) -t PENDING -r -h --format=%i") - stdout = stdout.decode("utf8") job_ids = set(stdout.split("\n")) return job_ids @@ -452,19 +488,25 @@ def get_pending_tasks(self): return [] -class JobSubmitThread(threading.Thread): - def __init__(self, scripts, job_sizes, futures, cfut_dir, *args, **kwargs): - super().__init__(*args, **kwargs) +class _JobSubmitThread(threading.Thread): + def __init__( + self, + scripts: List[str], + job_sizes: List[int], + futures: List["Future[str]"], + cfut_dir: str, + ): + super().__init__() self._stop_event = threading.Event() self.scripts = scripts self.job_sizes = job_sizes self.futures = futures self.cfut_dir = cfut_dir - def stop(self): + def stop(self) -> None: self._stop_event.set() - def run(self): + def run(self) -> None: max_submit_jobs = SlurmExecutor.get_max_submit_jobs() for script, job_size, future in zip(self.scripts, self.job_sizes, self.futures): diff --git a/cluster_tools/cluster_tools/util.py b/cluster_tools/cluster_tools/util.py deleted file mode 100644 index 71c77a3fc..000000000 --- a/cluster_tools/cluster_tools/util.py +++ /dev/null @@ -1,223 +0,0 @@ -import logging -import os -import random -import string -import subprocess -import threading -import time - - -def local_filename(filename=""): - return os.path.join(os.getenv("CFUT_DIR", ".cfut"), filename) - - -# Instantiate a dedicate generator to avoid being dependent on -# the global seed which some external code might have set. -random_generator = random.Random() - - -def random_string(length=32, chars=(string.ascii_letters + string.digits)): - return "".join(random_generator.choice(chars) for i in range(length)) - - -def call(command, stdin=None): - """Invokes a shell command as a subprocess, optionally with some - data sent to the standard input. Returns the standard output data, - the standard error, and the return code. - """ - if stdin is not None: - stdin_flag = subprocess.PIPE - else: - stdin_flag = None - proc = subprocess.Popen( - command, - shell=True, - stdin=stdin_flag, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - stdout, stderr = proc.communicate(stdin) - return stdout, stderr, proc.returncode - - -class CommandError(Exception): - """Raised when a shell command exits abnormally.""" - - def __init__(self, command, code, stderr): # pylint: disable=super-init-not-called - self.command = command - self.code = code - self.stderr = stderr - - def __str__(self): - return "%s exited with status %i: %s" % ( - repr(self.command), - self.code, - repr(self.stderr), - ) - - -def chcall(command, stdin=None): - """Like ``call`` but raises an exception when the return code is - nonzero. Only returns the stdout and stderr data. - """ - stdout, stderr, code = call(command, stdin) - if code != 0: - raise CommandError(command, code, stderr) - return stdout, stderr - - -def warn_after(job, seconds): - """ - Use as decorator to warn when a function is taking longer than {seconds} seconds. 
- """ - - def outer(fn): - def inner(*args, **kwargs): - exceeded_timeout = [False] - start_time = time.time() - - def warn_function(): - logging.warning( - "Function {} is taking suspiciously long (longer than {} seconds)".format( - job, seconds - ) - ) - exceeded_timeout[0] = True - - timer = threading.Timer(seconds, warn_function) - timer.start() - - try: - result = fn(*args, **kwargs) - if exceeded_timeout[0]: - end_time = time.time() - logging.warning( - "Function {} succeeded after all (took {} seconds)".format( - job, int(end_time - start_time) - ) - ) - finally: - timer.cancel() - return result - - return inner - - return outer - - -class FileWaitThread(threading.Thread): - """A thread that polls the filesystem waiting for a list of files to - be created. When a specified file is created, it invokes a callback. - """ - - MAX_RETRY = 30 - - def __init__(self, callback, executor, interval=2): - """The callable ``callback`` will be invoked with value - associated with the filename of each file that is created. - ``interval`` specifies the polling rate. - """ - threading.Thread.__init__(self) - self.callback = callback - self.interval = interval - self.waiting = {} - self.retryMap = {} - self.lock = threading.Lock() - self.shutdown = False - self.executor = executor - - def stop(self): - """Stop the thread soon.""" - with self.lock: - self.shutdown = True - - def waitFor(self, filename, value): - """Adds a new filename (and its associated callback value) to - the set of files being waited upon. - """ - with self.lock: - self.waiting[filename] = value - - def run(self): - def handle_completed_job(job_id, filename, failed_early): - self.callback(job_id, failed_early) - del self.waiting[filename] - - while True: - with self.lock: - if self.shutdown: - return - - pending_tasks = self.executor.get_pending_tasks() - - # Poll for each file. - for filename in list(self.waiting): - job_id = self.waiting[filename] - if job_id in pending_tasks: - # Don't check status of pending tasks, since this - # can vastly slow down the polling. - continue - - if os.path.exists(filename): - # Check for output file as a fast indicator for job completion - handle_completed_job(job_id, filename, False) - elif self.executor is not None: - status = self.executor.check_job_state(job_id) - - # We have to re-check for the output file since this could be created in the mean time - if os.path.exists(filename): - handle_completed_job(job_id, filename, False) - else: - if status == "completed": - self.retryMap[filename] = self.retryMap.get(filename, 0) - self.retryMap[filename] += 1 - - if self.retryMap[filename] <= FileWaitThread.MAX_RETRY: - # Retry by looping again - logging.warning( - "Job state is completed, but {} couldn't be found. 
Retrying {}/{}".format( - filename, - self.retryMap[filename], - FileWaitThread.MAX_RETRY, - ) - ) - else: - logging.error( - "Job state is completed, but {} couldn't be found.".format( - filename - ) - ) - handle_completed_job(job_id, filename, True) - - elif status == "failed": - handle_completed_job(job_id, filename, True) - elif status == "ignore": - pass - time.sleep(self.interval) - - -def get_function_name(fun): - # When using functools.partial, __name__ does not exist - try: - return fun.__name__ if hasattr(fun, "__name__") else fun.func.__name__ - except Exception: - return "" - - -def enrich_future_with_uncaught_warning(f): - def warn_on_exception(future): - maybe_exception = future.exception() - if maybe_exception is not None: - logging.error( - "A future crashed with an exception: {}. Future: {}".format( - maybe_exception, future - ) - ) - - if not hasattr(f, "is_wrapped_by_cluster_tools"): - f.is_wrapped_by_cluster_tools = True - f.add_done_callback(warn_on_exception) - - -def with_preliminary_postfix(name): - return f"{name}.preliminary" diff --git a/cluster_tools/poetry.lock b/cluster_tools/poetry.lock index d46852e98..a6ae9c749 100644 --- a/cluster_tools/poetry.lock +++ b/cluster_tools/poetry.lock @@ -58,7 +58,6 @@ python-versions = ">=3.6.2" [package.dependencies] click = ">=7.1.2" -dataclasses = {version = ">=0.6", markers = "python_version < \"3.7\""} mypy-extensions = ">=0.4.3" pathspec = ">=0.9.0,<1" platformdirs = ">=2" @@ -123,14 +122,6 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -[[package]] -name = "dataclasses" -version = "0.8" -description = "A backport of the dataclasses module for Python 3.6" -category = "dev" -optional = false -python-versions = ">=3.6, <3.7" - [[package]] name = "executing" version = "0.8.2" @@ -259,21 +250,23 @@ python-versions = "*" [[package]] name = "mypy" -version = "0.910" +version = "0.991" description = "Optional static typing for Python" category = "dev" optional = false -python-versions = ">=3.5" +python-versions = ">=3.7" [package.dependencies] -mypy-extensions = ">=0.4.3,<0.5.0" -toml = "*" -typed-ast = {version = ">=1.4.0,<1.5.0", markers = "python_version < \"3.8\""} -typing-extensions = ">=3.7.4" +mypy-extensions = ">=0.4.3" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typed-ast = {version = ">=1.4.0,<2", markers = "python_version < \"3.8\""} +typing-extensions = ">=3.10" [package.extras] dmypy = ["psutil (>=4.0)"] -python2 = ["typed-ast (>=1.4.0,<1.5.0)"] +install-types = ["pip"] +python2 = ["typed-ast (>=1.4.0,<2)"] +reports = ["lxml"] [[package]] name = "mypy-extensions" @@ -324,8 +317,8 @@ optional = false python-versions = ">=3.6" [package.extras] -test = ["pytest-mock (>=3.6)", "pytest-cov (>=2.7)", "pytest (>=6)", "appdirs (==1.4.4)"] -docs = ["sphinx-autodoc-typehints (>=1.12)", "proselint (>=0.10.2)", "furo (>=2021.7.5b38)", "Sphinx (>=4)"] +docs = ["Sphinx (>=4)", "furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)"] +test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"] [[package]] name = "pluggy" @@ -339,8 +332,8 @@ python-versions = ">=3.6" importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} [package.extras] -testing = ["pytest-benchmark", "pytest"] -dev = ["tox", "pre-commit"] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] [[package]] name = "py" @@ -524,11 +517,11 @@ python-versions = "*" [[package]] name = 
"typing-extensions" -version = "4.0.1" -description = "Backported and Experimental Type Hints for Python 3.6+" -category = "dev" +version = "4.4.0" +description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [[package]] name = "urllib3" @@ -552,9 +545,9 @@ optional = false python-versions = ">=3.6" [package.extras] +docs = ["Sphinx (>=3.4)", "sphinx-rtd-theme (>=0.5)"] +optional = ["python-socks", "wsaccel"] test = ["websockets"] -optional = ["wsaccel", "python-socks"] -docs = ["sphinx-rtd-theme (>=0.5)", "Sphinx (>=3.4)"] [[package]] name = "wrapt" @@ -578,8 +571,8 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" -python-versions = "^3.6.2" -content-hash = "3870e5d3add61d10ed8f6ca61d8dd6e5e74c29e8ec196fceefd16d323112a477" +python-versions = ">=3.7" +content-hash = "b737dc0dc9c0210757c287ff6c0aceb39b6da8805f8fed3777e7d43773089281" [metadata.files] astroid = [ @@ -622,10 +615,6 @@ colorama = [ {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] -dataclasses = [ - {file = "dataclasses-0.8-py3-none-any.whl", hash = "sha256:0201d89fa866f68c8ebd9d08ee6ff50c0b255f8ec63a71c16fda7af82bb887bf"}, - {file = "dataclasses-0.8.tar.gz", hash = "sha256:8479067f342acf957dc82ec415d355ab5edb7e7646b90dc6e2fd1d96ad084c97"}, -] executing = [ {file = "executing-0.8.2-py2.py3-none-any.whl", hash = "sha256:32fc6077b103bd19e6494a72682d66d5763cf20a106d5aa7c5ccbea4e47b0df7"}, {file = "executing-0.8.2.tar.gz", hash = "sha256:c23bf42e9a7b9b212f185b1b2c3c91feb895963378887bb10e64a2e612ec0023"}, @@ -687,29 +676,36 @@ mccabe = [ {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, ] mypy = [ - {file = "mypy-0.910-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:a155d80ea6cee511a3694b108c4494a39f42de11ee4e61e72bc424c490e46457"}, - {file = "mypy-0.910-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:b94e4b785e304a04ea0828759172a15add27088520dc7e49ceade7834275bedb"}, - {file = "mypy-0.910-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:088cd9c7904b4ad80bec811053272986611b84221835e079be5bcad029e79dd9"}, - {file = "mypy-0.910-cp35-cp35m-win_amd64.whl", hash = "sha256:adaeee09bfde366d2c13fe6093a7df5df83c9a2ba98638c7d76b010694db760e"}, - {file = "mypy-0.910-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:ecd2c3fe726758037234c93df7e98deb257fd15c24c9180dacf1ef829da5f921"}, - {file = "mypy-0.910-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:d9dd839eb0dc1bbe866a288ba3c1afc33a202015d2ad83b31e875b5905a079b6"}, - {file = "mypy-0.910-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:3e382b29f8e0ccf19a2df2b29a167591245df90c0b5a2542249873b5c1d78212"}, - {file = "mypy-0.910-cp36-cp36m-win_amd64.whl", hash = "sha256:53fd2eb27a8ee2892614370896956af2ff61254c275aaee4c230ae771cadd885"}, - {file = "mypy-0.910-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b6fb13123aeef4a3abbcfd7e71773ff3ff1526a7d3dc538f3929a49b42be03f0"}, - {file = "mypy-0.910-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e4dab234478e3bd3ce83bac4193b2ecd9cf94e720ddd95ce69840273bf44f6de"}, - {file = "mypy-0.910-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:7df1ead20c81371ccd6091fa3e2878559b5c4d4caadaf1a484cf88d93ca06703"}, - {file 
= "mypy-0.910-cp37-cp37m-win_amd64.whl", hash = "sha256:0aadfb2d3935988ec3815952e44058a3100499f5be5b28c34ac9d79f002a4a9a"}, - {file = "mypy-0.910-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec4e0cd079db280b6bdabdc807047ff3e199f334050db5cbb91ba3e959a67504"}, - {file = "mypy-0.910-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:119bed3832d961f3a880787bf621634ba042cb8dc850a7429f643508eeac97b9"}, - {file = "mypy-0.910-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:866c41f28cee548475f146aa4d39a51cf3b6a84246969f3759cb3e9c742fc072"}, - {file = "mypy-0.910-cp38-cp38-win_amd64.whl", hash = "sha256:ceb6e0a6e27fb364fb3853389607cf7eb3a126ad335790fa1e14ed02fba50811"}, - {file = "mypy-0.910-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a85e280d4d217150ce8cb1a6dddffd14e753a4e0c3cf90baabb32cefa41b59e"}, - {file = "mypy-0.910-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:42c266ced41b65ed40a282c575705325fa7991af370036d3f134518336636f5b"}, - {file = "mypy-0.910-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3c4b8ca36877fc75339253721f69603a9c7fdb5d4d5a95a1a1b899d8b86a4de2"}, - {file = "mypy-0.910-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:c0df2d30ed496a08de5daed2a9ea807d07c21ae0ab23acf541ab88c24b26ab97"}, - {file = "mypy-0.910-cp39-cp39-win_amd64.whl", hash = "sha256:c6c2602dffb74867498f86e6129fd52a2770c48b7cd3ece77ada4fa38f94eba8"}, - {file = "mypy-0.910-py3-none-any.whl", hash = "sha256:ef565033fa5a958e62796867b1df10c40263ea9ded87164d67572834e57a174d"}, - {file = "mypy-0.910.tar.gz", hash = "sha256:704098302473cb31a218f1775a873b376b30b4c18229421e9e9dc8916fd16150"}, + {file = "mypy-0.991-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7d17e0a9707d0772f4a7b878f04b4fd11f6f5bcb9b3813975a9b13c9332153ab"}, + {file = "mypy-0.991-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0714258640194d75677e86c786e80ccf294972cc76885d3ebbb560f11db0003d"}, + {file = "mypy-0.991-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0c8f3be99e8a8bd403caa8c03be619544bc2c77a7093685dcf308c6b109426c6"}, + {file = "mypy-0.991-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9ec663ed6c8f15f4ae9d3c04c989b744436c16d26580eaa760ae9dd5d662eb"}, + {file = "mypy-0.991-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4307270436fd7694b41f913eb09210faff27ea4979ecbcd849e57d2da2f65305"}, + {file = "mypy-0.991-cp310-cp310-win_amd64.whl", hash = "sha256:901c2c269c616e6cb0998b33d4adbb4a6af0ac4ce5cd078afd7bc95830e62c1c"}, + {file = "mypy-0.991-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d13674f3fb73805ba0c45eb6c0c3053d218aa1f7abead6e446d474529aafc372"}, + {file = "mypy-0.991-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1c8cd4fb70e8584ca1ed5805cbc7c017a3d1a29fb450621089ffed3e99d1857f"}, + {file = "mypy-0.991-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:209ee89fbb0deed518605edddd234af80506aec932ad28d73c08f1400ef80a33"}, + {file = "mypy-0.991-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37bd02ebf9d10e05b00d71302d2c2e6ca333e6c2a8584a98c00e038db8121f05"}, + {file = "mypy-0.991-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:26efb2fcc6b67e4d5a55561f39176821d2adf88f2745ddc72751b7890f3194ad"}, + {file = "mypy-0.991-cp311-cp311-win_amd64.whl", hash = "sha256:3a700330b567114b673cf8ee7388e949f843b356a73b5ab22dd7cff4742a5297"}, + {file = "mypy-0.991-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1f7d1a520373e2272b10796c3ff721ea1a0712288cafaa95931e66aa15798813"}, + {file = 
"mypy-0.991-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:641411733b127c3e0dab94c45af15fea99e4468f99ac88b39efb1ad677da5711"}, + {file = "mypy-0.991-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3d80e36b7d7a9259b740be6d8d906221789b0d836201af4234093cae89ced0cd"}, + {file = "mypy-0.991-cp37-cp37m-win_amd64.whl", hash = "sha256:e62ebaad93be3ad1a828a11e90f0e76f15449371ffeecca4a0a0b9adc99abcef"}, + {file = "mypy-0.991-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b86ce2c1866a748c0f6faca5232059f881cda6dda2a893b9a8373353cfe3715a"}, + {file = "mypy-0.991-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ac6e503823143464538efda0e8e356d871557ef60ccd38f8824a4257acc18d93"}, + {file = "mypy-0.991-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cca5adf694af539aeaa6ac633a7afe9bbd760df9d31be55ab780b77ab5ae8bf"}, + {file = "mypy-0.991-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12c56bf73cdab116df96e4ff39610b92a348cc99a1307e1da3c3768bbb5b135"}, + {file = "mypy-0.991-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:652b651d42f155033a1967739788c436491b577b6a44e4c39fb340d0ee7f0d70"}, + {file = "mypy-0.991-cp38-cp38-win_amd64.whl", hash = "sha256:4175593dc25d9da12f7de8de873a33f9b2b8bdb4e827a7cae952e5b1a342e243"}, + {file = "mypy-0.991-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:98e781cd35c0acf33eb0295e8b9c55cdbef64fcb35f6d3aa2186f289bed6e80d"}, + {file = "mypy-0.991-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6d7464bac72a85cb3491c7e92b5b62f3dcccb8af26826257760a552a5e244aa5"}, + {file = "mypy-0.991-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c9166b3f81a10cdf9b49f2d594b21b31adadb3d5e9db9b834866c3258b695be3"}, + {file = "mypy-0.991-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8472f736a5bfb159a5e36740847808f6f5b659960115ff29c7cecec1741c648"}, + {file = "mypy-0.991-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5e80e758243b97b618cdf22004beb09e8a2de1af481382e4d84bc52152d1c476"}, + {file = "mypy-0.991-cp39-cp39-win_amd64.whl", hash = "sha256:74e259b5c19f70d35fcc1ad3d56499065c601dfe94ff67ae48b85596b9ec1461"}, + {file = "mypy-0.991-py3-none-any.whl", hash = "sha256:de32edc9b0a7e67c2775e574cb061a537660e51210fbf6006b0b36ea695ae9bb"}, + {file = "mypy-0.991.tar.gz", hash = "sha256:3c0165ba8f354a6d9881809ef29f1a9318a236a6d81c690094c5df32107bde06"}, ] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, @@ -797,6 +793,13 @@ pyyaml = [ {file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"}, {file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"}, {file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"}, + {file = 
"PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"}, + {file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"}, + {file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"}, {file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"}, {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"}, {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"}, @@ -881,8 +884,8 @@ typed-ast = [ {file = "typed_ast-1.4.3.tar.gz", hash = "sha256:fb1bbeac803adea29cedd70781399c99138358c26d05fcbd23c13016b7f5ec65"}, ] typing-extensions = [ - {file = "typing_extensions-4.0.1-py3-none-any.whl", hash = "sha256:7f001e5ac290a0c0401508864c7ec868be4e701886d5b573a9528ed3973d9d3b"}, - {file = "typing_extensions-4.0.1.tar.gz", hash = "sha256:4ca091dea149f945ec56afb48dae714f21e8692ef22a395223bcd328961b6a0e"}, + {file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"}, + {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, ] urllib3 = [ {file = "urllib3-1.26.8-py2.py3-none-any.whl", hash = "sha256:000ca7f471a233c2251c6c7023ee85305721bfdf18621ebff4fd17a8653427ed"}, diff --git a/cluster_tools/pyproject.toml b/cluster_tools/pyproject.toml index 6b3d509a6..0a6e94a4a 100644 --- a/cluster_tools/pyproject.toml +++ b/cluster_tools/pyproject.toml @@ -8,14 +8,15 @@ license = "MIT" repository = "https://github.com/scalableminds/webknossos-libs" [tool.poetry.dependencies] -python = "^3.6.2" +python = ">=3.7" kubernetes = "^23.3.0" +typing-extensions = "^4.4.0" [tool.poetry.dev-dependencies] black = "21.12b0" icecream = "^2.1.1" isort = "^5.9.3" -mypy = "0.910" +mypy = "0.991" pylint = "^2.12.2" pytest = "^6.2.4" @@ -27,9 +28,9 @@ build-backend = "poetry.core.masonry.api" profile = "black" [tool.mypy] +disallow_untyped_defs = true ignore_missing_imports = true namespace_packages = true strict_equality = true show_error_codes = true -# disallow_untyped_defs = true no_implicit_optional = true diff --git a/cluster_tools/tests/test_all.py b/cluster_tools/tests/test_all.py index ed29d42a0..2a34381b6 100644 --- a/cluster_tools/tests/test_all.py +++ b/cluster_tools/tests/test_all.py @@ -6,6 +6,7 @@ from enum import Enum from functools import partial from pathlib import Path +from typing import List import pytest @@ -13,11 +14,11 @@ # "Worker" functions. 
-def square(n): +def square(n: float) -> float: return n * n -def sleep(duration): +def sleep(duration: float) -> float: time.sleep(duration) return duration @@ -25,12 +26,12 @@ def sleep(duration): logging.basicConfig() -def raise_if(msg, bool): - if bool: +def raise_if(msg: str, _bool: bool) -> None: + if _bool: raise Exception("raise_if was called with True: {}".format(msg)) -def get_executors(with_debug_sequential=False): +def get_executors(with_debug_sequential: bool = False) -> List[cluster_tools.Executor]: executor_keys = { "slurm", "kubernetes", @@ -46,7 +47,7 @@ def get_executors(with_debug_sequential=False): os.environ["PYTEST_EXECUTORS"].split(",") ) - executors = [] + executors: List[cluster_tools.Executor] = [] if "slurm" in executor_keys: executors.append( cluster_tools.get_executor( @@ -80,7 +81,7 @@ def get_executors(with_debug_sequential=False): @pytest.mark.skip( reason="The test is flaky on the CI for some reason. Disable it for now." ) -def test_uncaught_warning(): +def test_uncaught_warning() -> None: """ This test ensures that there are warnings for "uncaught" futures. """ @@ -97,7 +98,7 @@ def test_uncaught_warning(): cases = [False, True] - def expect_marker(marker, msg, should_exist=True): + def expect_marker(marker: str, msg: str, should_exist: bool = True) -> None: maybe_negate = lambda b: b if should_exist else not b fh.flush() @@ -148,8 +149,8 @@ def expect_marker(marker, msg, should_exist=True): logger.removeHandler(fh) -def test_submit(): - def run_square_numbers(executor): +def test_submit() -> None: + def run_square_numbers(executor: cluster_tools.Executor) -> None: with executor: job_count = 3 job_range = range(job_count) @@ -161,16 +162,16 @@ def run_square_numbers(executor): run_square_numbers(exc) -def get_pid(): +def get_pid() -> int: return os.getpid() -def test_process_id(): +def test_process_id() -> None: outer_pid = os.getpid() - def compare_pids(executor): + def compare_pids(executor: cluster_tools.Executor) -> None: with executor: future = executor.submit(get_pid) inner_pid = future.result() @@ -190,7 +191,7 @@ def compare_pids(executor): compare_pids(exc) -def test_unordered_sleep(): +def test_unordered_sleep() -> None: """Get host identifying information about the servers running our jobs. 
""" @@ -206,7 +207,7 @@ def test_unordered_sleep(): assert future.result() == duration -def test_unordered_map(): +def test_unordered_map() -> None: for exc in get_executors(): with exc: durations = [15, 1] @@ -220,7 +221,7 @@ def test_unordered_map(): assert result == duration -def test_map_to_futures(): +def test_map_to_futures() -> None: for exc in get_executors(): with exc: durations = [15, 1] @@ -233,11 +234,11 @@ def test_map_to_futures(): if not isinstance(exc, cluster_tools.SequentialExecutor): durations.sort() - for duration, result in zip(durations, results): - assert result == duration + for duration_, result in zip(durations, results): + assert result == duration_ -def test_empty_map_to_futures(): +def test_empty_map_to_futures() -> None: for exc in get_executors(): with exc: futures = exc.map_to_futures(sleep, []) @@ -245,12 +246,11 @@ def test_empty_map_to_futures(): assert len(results) == 0 -def output_pickle_path_getter(tmp_dir, chunk): - +def output_pickle_path_getter(tmp_dir: str, chunk: int) -> Path: return Path(tmp_dir) / f"test_{chunk}.pickle" -def test_map_to_futures_with_pickle_paths(): +def test_map_to_futures_with_pickle_paths() -> None: for exc in get_executors(with_debug_sequential=True): with tempfile.TemporaryDirectory(dir=".") as tmp_dir: @@ -271,17 +271,17 @@ def test_map_to_futures_with_pickle_paths(): assert 2 in results assert 1 in results - for duration in durations: + for duration_ in durations: assert Path( - output_pickle_path_getter(tmp_dir, duration) - ).exists(), f"File for chunk {duration} should exist." + output_pickle_path_getter(tmp_dir, duration_) + ).exists(), f"File for chunk {duration_} should exist." -def test_submit_with_pickle_paths(): +def test_submit_with_pickle_paths() -> None: for (idx, exc) in enumerate(get_executors()): with tempfile.TemporaryDirectory(dir=".") as tmp_dir: - def run_square_numbers(idx, executor): + def run_square_numbers(idx: int, executor: cluster_tools.Executor) -> Path: with executor: job_count = 3 job_range = range(job_count) @@ -291,7 +291,7 @@ def run_square_numbers(idx, executor): output_path = Path(tmp_dir) / f"{idx}_{n}.pickle" cfut_options = {"output_pickle_path": output_path} futures.append( - executor.submit(square, n, __cfut_options=cfut_options) + executor.submit(square, n, __cfut_options=cfut_options) # type: ignore[call-arg] ) for future, job_index in zip(futures, job_range): @@ -302,8 +302,8 @@ def run_square_numbers(idx, executor): assert output_path.exists(), "Output pickle file should exist." 
-def test_map(): - def run_map(executor): +def test_map() -> None: + def run_map(executor: cluster_tools.Executor) -> None: with executor: result = list(executor.map(square, [2, 3, 4])) assert result == [4, 9, 16] @@ -312,8 +312,8 @@ def run_map(executor): run_map(exc) -def test_map_lazy(): - def run_map(executor): +def test_map_lazy() -> None: + def run_map(executor: cluster_tools.Executor) -> None: with executor: result = executor.map(square, [2, 3, 4]) assert list(result) == [4, 9, 16] @@ -322,8 +322,8 @@ def run_map(executor): run_map(exc) -def test_executor_args(): - def pass_with(exc): +def test_executor_args() -> None: + def pass_with(exc: cluster_tools.Executor) -> None: with exc: pass @@ -337,11 +337,11 @@ class DummyEnum(Enum): PEAR = 2 -def enum_consumer(value): +def enum_consumer(value: DummyEnum) -> None: assert value == DummyEnum.BANANA -def test_cloudpickle_serialization(): +def test_cloudpickle_serialization() -> None: enum_consumer_inner = enum_consumer for fn in [enum_consumer, enum_consumer_inner]: @@ -355,7 +355,7 @@ def test_cloudpickle_serialization(): assert True -def test_map_to_futures_with_debug_sequential(): +def test_map_to_futures_with_debug_sequential() -> None: with cluster_tools.get_executor("debug_sequential") as exc: durations = [4, 1] @@ -370,5 +370,5 @@ def test_map_to_futures_with_debug_sequential(): for i, duration in enumerate(futures): results.append(duration.result()) - for duration, result in zip(durations, results): - assert result == duration + for duration_, result in zip(durations, results): + assert result == duration_ diff --git a/cluster_tools/tests/test_deref_main.py b/cluster_tools/tests/test_deref_main.py index ee2e5efe1..43583bf39 100644 --- a/cluster_tools/tests/test_deref_main.py +++ b/cluster_tools/tests/test_deref_main.py @@ -1,3 +1,7 @@ +from typing import Tuple, Type + +from typing_extensions import Literal + import cluster_tools @@ -5,14 +9,14 @@ class TestClass: pass -def deref_fun_helper(obj): +def deref_fun_helper(obj: Tuple[Type[TestClass], TestClass, int, int]) -> None: clss, inst, one, two = obj assert one == 1 assert two == 2 assert isinstance(inst, clss) -def test_dereferencing_main(): +def test_dereferencing_main() -> None: with cluster_tools.get_executor( "slurm", debug=True, job_resources={"mem": "10M"} ) as executor: diff --git a/cluster_tools/tests/test_kubernetes.py b/cluster_tools/tests/test_kubernetes.py index 52aea9e3a..03803aede 100644 --- a/cluster_tools/tests/test_kubernetes.py +++ b/cluster_tools/tests/test_kubernetes.py @@ -1,17 +1,18 @@ import os +from typing import List import cluster_tools -def square(n): +def square(n: float) -> float: return n * n -def list_dir(path): +def list_dir(path: str) -> List[str]: return os.listdir(path) -def test_simple(): +def test_simple() -> None: with cluster_tools.get_executor( "kubernetes", job_resources={ @@ -26,7 +27,7 @@ def test_simple(): assert list(exec.map(square, [n + 2 for n in range(2)])) == [4, 9] -def test_mounts(): +def test_mounts() -> None: parent_dir = os.path.abspath(os.path.join(os.pardir, os.curdir)) with cluster_tools.get_executor( "kubernetes", diff --git a/cluster_tools/tests/test_multiprocessing.py b/cluster_tools/tests/test_multiprocessing.py index 24b45610a..5626f6aa8 100644 --- a/cluster_tools/tests/test_multiprocessing.py +++ b/cluster_tools/tests/test_multiprocessing.py @@ -9,22 +9,22 @@ logging.basicConfig() -def expect_fork(): +def expect_fork() -> bool: assert mp.get_start_method() == "fork" return True -def expect_forkserver(): +def 
expect_forkserver() -> bool: assert mp.get_start_method() == "forkserver" return True -def expect_spawn(): +def expect_spawn() -> bool: assert mp.get_start_method() == "spawn" return True -def test_map_with_spawn(): +def test_map_with_spawn() -> None: with cluster_tools.get_executor("multiprocessing", max_workers=5) as executor: assert executor.submit( expect_spawn @@ -52,14 +52,14 @@ def test_map_with_spawn(): ).result(), "Multiprocessing should use `fork` if requested" -def accept_high_mem(data): +def accept_high_mem(data: str) -> int: return len(data) @pytest.mark.skip( reason="This test does not pass on the CI. Probably because the machine does not have enough RAM." ) -def test_high_ram_usage(): +def test_high_ram_usage() -> None: very_long_string = " " * 10 ** 6 * 2500 os.environ["MULTIPROCESSING_VIA_IO"] = "True" @@ -68,7 +68,7 @@ def test_high_ram_usage(): fut1 = executor.submit( accept_high_mem, very_long_string, - __cfut_options={"output_pickle_path": "/tmp/test.pickle"}, + __cfut_options={"output_pickle_path": "/tmp/test.pickle"}, # type: ignore[call-arg] ) assert fut1.result() == len(very_long_string) @@ -79,8 +79,8 @@ def test_high_ram_usage(): del os.environ["MULTIPROCESSING_VIA_IO"] -def test_executor_args(): - def pass_with(exc): +def test_executor_args() -> None: + def pass_with(exc: cluster_tools.MultiprocessingExecutor) -> None: with exc: pass @@ -88,13 +88,13 @@ def pass_with(exc): # Test should succeed if the above lines don't raise an exception -def test_multiprocessing_validation(): - +def test_multiprocessing_validation() -> None: import sys from subprocess import PIPE, STDOUT, Popen cmd = [sys.executable, "guardless_multiprocessing.py"] p = Popen(cmd, shell=False, stdin=PIPE, stdout=PIPE, stderr=STDOUT) + assert p.stdout is not None output = p.stdout.read() assert "current process has finished its bootstrapping phase." in str(output), "S" diff --git a/cluster_tools/tests/test_slurm.py b/cluster_tools/tests/test_slurm.py index 731e4f8d3..12030544e 100644 --- a/cluster_tools/tests/test_slurm.py +++ b/cluster_tools/tests/test_slurm.py @@ -11,20 +11,20 @@ from collections import Counter from functools import partial from pathlib import Path -from typing import Optional +from typing import Any, Optional import pytest import cluster_tools -from cluster_tools.util import call +from cluster_tools._utils.call import call, chcall # "Worker" functions. 
-def square(n): +def square(n: float) -> float: return n * n -def sleep(duration): +def sleep(duration: float) -> float: time.sleep(duration) return duration @@ -32,12 +32,12 @@ def sleep(duration): logging.basicConfig() -def expect_fork(): +def expect_fork() -> bool: assert mp.get_start_method() == "fork" return True -def test_map_with_spawn(): +def test_map_with_spawn() -> None: with cluster_tools.get_executor( "slurm", max_workers=5, start_method="spawn" ) as executor: @@ -46,16 +46,16 @@ def test_map_with_spawn(): ).result(), "Slurm should ignore provided start_method" -def test_slurm_submit_returns_job_ids(): +def test_slurm_submit_returns_job_ids() -> None: exc = cluster_tools.get_executor("slurm", debug=True) with exc: future = exc.submit(square, 2) - assert isinstance(future.cluster_jobid, int) - assert future.cluster_jobid > 0 + assert isinstance(future.cluster_jobid, str) # type: ignore[attr-defined] + assert int(future.cluster_jobid) > 0 # type: ignore[attr-defined] assert future.result() == 4 -def test_slurm_cfut_dir(): +def test_slurm_cfut_dir() -> None: cfut_dir = "./test_cfut_dir" if os.path.exists(cfut_dir): shutil.rmtree(cfut_dir) @@ -69,7 +69,7 @@ def test_slurm_cfut_dir(): assert len(os.listdir(cfut_dir)) == 2 -def test_slurm_max_submit_user(): +def test_slurm_max_submit_user() -> None: max_submit_jobs = 6 # MaxSubmitJobs can either be defined at the user or at the qos level @@ -77,11 +77,10 @@ def test_slurm_max_submit_user(): executor = cluster_tools.get_executor("slurm", debug=True) original_max_submit_jobs = executor.get_max_submit_jobs() - _, _, exit_code = call( - f"echo y | sacctmgr modify {command} set MaxSubmitJobs={max_submit_jobs}" - ) try: - assert exit_code == 0 + chcall( + f"echo y | sacctmgr modify {command} set MaxSubmitJobs={max_submit_jobs}" + ) new_max_submit_jobs = executor.get_max_submit_jobs() assert new_max_submit_jobs == max_submit_jobs @@ -92,19 +91,16 @@ def test_slurm_max_submit_user(): result = [fut.result() for fut in futures] assert result == [i ** 2 for i in range(10)] - job_ids = {fut.cluster_jobid for fut in futures} + job_ids = {fut.cluster_jobid for fut in futures} # type: ignore[attr-defined] # The 10 work packages should have been scheduled as 2 separate jobs. assert len(job_ids) == 2 finally: - _, _, exit_code = call( - f"echo y | sacctmgr modify {command} set MaxSubmitJobs=-1" - ) - assert exit_code == 0 + chcall(f"echo y | sacctmgr modify {command} set MaxSubmitJobs=-1") reset_max_submit_jobs = executor.get_max_submit_jobs() assert reset_max_submit_jobs == original_max_submit_jobs -def test_slurm_max_submit_user_env(): +def test_slurm_max_submit_user_env() -> None: max_submit_jobs = 4 executor = cluster_tools.get_executor("slurm", debug=True) @@ -122,7 +118,7 @@ def test_slurm_max_submit_user_env(): result = [fut.result() for fut in futures] assert result == [i ** 2 for i in range(10)] - job_ids = {fut.cluster_jobid for fut in futures} + job_ids = {fut.cluster_jobid for fut in futures} # type: ignore[attr-defined] # The 10 work packages should have been scheduled as 3 separate jobs. 
assert len(job_ids) == 3 finally: @@ -131,13 +127,11 @@ def test_slurm_max_submit_user_env(): assert reset_max_submit_jobs == original_max_submit_jobs -def test_slurm_deferred_submit(): +def test_slurm_deferred_submit() -> None: max_submit_jobs = 1 # Only one job can be scheduled at a time - _, _, exit_code = call( - f"echo y | sacctmgr modify qos normal set MaxSubmitJobs={max_submit_jobs}" - ) + call(f"echo y | sacctmgr modify qos normal set MaxSubmitJobs={max_submit_jobs}") executor = cluster_tools.get_executor("slurm", debug=True) try: @@ -155,28 +149,26 @@ def test_slurm_deferred_submit(): # since only one job is scheduled at a time and each job takes 0.5 seconds assert time_of_result - time_of_start > 1 finally: - _, _, exit_code = call( - "echo y | sacctmgr modify qos normal set MaxSubmitJobs=-1" - ) + call("echo y | sacctmgr modify qos normal set MaxSubmitJobs=-1") -def wait_until_first_job_was_submitted(executor, state: Optional[str] = None): +def wait_until_first_job_was_submitted( + executor: cluster_tools.SlurmExecutor, state: Optional[str] = None +) -> None: # Since the job submission is not synchronous, we need to poll # to find out when the first job was submitted while executor.get_number_of_submitted_jobs(state) <= 0: time.sleep(0.1) -def test_slurm_deferred_submit_shutdown(): +def test_slurm_deferred_submit_shutdown() -> None: # Test that the SlurmExecutor stops scheduling jobs in a separate thread # once it was killed even if the executor was used multiple times and # therefore started multiple job submission threads max_submit_jobs = 1 # Only one job can be scheduled at a time - _, _, exit_code = call( - f"echo y | sacctmgr modify qos normal set MaxSubmitJobs={max_submit_jobs}" - ) + call(f"echo y | sacctmgr modify qos normal set MaxSubmitJobs={max_submit_jobs}") executor = cluster_tools.get_executor("slurm", debug=True) try: @@ -202,12 +194,10 @@ def test_slurm_deferred_submit_shutdown(): time.sleep(0.5) finally: - _, _, exit_code = call( - "echo y | sacctmgr modify qos normal set MaxSubmitJobs=-1" - ) + call("echo y | sacctmgr modify qos normal set MaxSubmitJobs=-1") -def test_slurm_job_canceling_on_shutdown(): +def test_slurm_job_canceling_on_shutdown() -> None: # Test that scheduled jobs are canceled on shutdown, regardless # of whether they are pending or running. 
max_running_size = 2 @@ -252,7 +242,7 @@ def test_slurm_job_canceling_on_shutdown(): del os.environ["SLURM_MAX_RUNNING_SIZE"] -def test_slurm_number_of_submitted_jobs(): +def test_slurm_number_of_submitted_jobs() -> None: number_of_jobs = 6 executor = cluster_tools.get_executor("slurm", debug=True) @@ -270,19 +260,16 @@ def test_slurm_number_of_submitted_jobs(): assert executor.get_number_of_submitted_jobs() == 0 -def test_slurm_max_array_size(): +def test_slurm_max_array_size() -> None: max_array_size = 2 executor = cluster_tools.get_executor("slurm", debug=True) original_max_array_size = executor.get_max_array_size() command = f"MaxArraySize={max_array_size}" - _, _, exit_code = call( - f"echo -e '{command}' >> /etc/slurm/slurm.conf && scontrol reconfigure" - ) try: - assert exit_code == 0 + chcall(f"echo -e '{command}' >> /etc/slurm/slurm.conf && scontrol reconfigure") new_max_array_size = executor.get_max_array_size() assert new_max_array_size == max_array_size @@ -290,22 +277,19 @@ def test_slurm_max_array_size(): with executor: futures = executor.map_to_futures(square, range(6)) concurrent.futures.wait(futures) - job_ids = [fut.cluster_jobid for fut in futures] + job_ids = [fut.cluster_jobid for fut in futures] # type: ignore[attr-defined] # Count how often each job_id occurs which corresponds to the array size of the job occurences = list(Counter(job_ids).values()) assert all(array_size <= max_array_size for array_size in occurences) finally: - _, _, exit_code = call( - f"sed -i 's/{command}//g' /etc/slurm/slurm.conf && scontrol reconfigure" - ) - assert exit_code == 0 + chcall(f"sed -i 's/{command}//g' /etc/slurm/slurm.conf && scontrol reconfigure") reset_max_array_size = executor.get_max_array_size() assert reset_max_array_size == original_max_array_size -def test_slurm_max_array_size_env(): +def test_slurm_max_array_size_env() -> None: max_array_size = 2 executor = cluster_tools.get_executor("slurm", debug=True) @@ -320,7 +304,7 @@ def test_slurm_max_array_size_env(): with executor: futures = executor.map_to_futures(square, range(6)) concurrent.futures.wait(futures) - job_ids = [fut.cluster_jobid for fut in futures] + job_ids = [fut.cluster_jobid for fut in futures] # type: ignore[attr-defined] # Count how often each job_id occurs which corresponds to the array size of the job occurences = list(Counter(job_ids).values()) @@ -335,12 +319,12 @@ def test_slurm_max_array_size_env(): test_output_str = "Test-Output" -def log(string): +def log(string: str) -> None: logging.debug(string) -def test_pickled_logging(): - def execute_with_log_level(log_level): +def test_pickled_logging() -> None: + def execute_with_log_level(log_level: int) -> str: logging_config = {"level": log_level} with cluster_tools.get_executor( "slurm", @@ -351,7 +335,7 @@ def execute_with_log_level(log_level): fut = executor.submit(log, test_output_str) fut.result() - output = ".cfut/slurmpy.{}.log.stdout".format(fut.cluster_jobid) + output = ".cfut/slurmpy.{}.log.stdout".format(fut.cluster_jobid) # type: ignore[attr-defined] with open(output, "r") as file: return file.read() @@ -363,7 +347,7 @@ def execute_with_log_level(log_level): assert not (test_output_str in info_out) -def test_tailed_logging(): +def test_tailed_logging() -> None: with cluster_tools.get_executor( "slurm", @@ -382,16 +366,15 @@ def test_tailed_logging(): assert "jid" in f.getvalue() -def fail(val): +def fail(val: Any) -> None: raise Exception("Fail()") -def output_pickle_path_getter(tmp_dir, chunk): - +def output_pickle_path_getter(tmp_dir: str, 
chunk: int) -> Path: return Path(tmp_dir) / f"test_{chunk}.pickle" -def test_preliminary_file_submit(): +def test_preliminary_file_submit() -> None: with tempfile.TemporaryDirectory(dir=".") as tmp_dir: output_pickle_path = Path(tmp_dir) / "test.pickle" @@ -413,20 +396,20 @@ def test_preliminary_file_submit(): assert not output_pickle_path.exists(), "Final output file should not exist" # Schedule succeeding job with same output path - fut = executor.submit( + fut_2 = executor.submit( square, 3, - __cfut_options={"output_pickle_path": str(output_pickle_path)}, + __cfut_options={"output_pickle_path": str(output_pickle_path)}, # type: ignore[call-arg] ) - assert fut.result() == 9 + assert fut_2.result() == 9 assert output_pickle_path.exists(), "Final output file should exist" assert ( not preliminary_output_path.exists() ), "Preliminary output file should not exist anymore" -def test_executor_args(): - def pass_with(exc): +def test_executor_args() -> None: + def pass_with(exc: cluster_tools.SlurmExecutor) -> None: with exc: pass @@ -438,7 +421,7 @@ def pass_with(exc): # Test should succeed if the above lines don't raise an exception -def test_preliminary_file_map(): +def test_preliminary_file_map() -> None: a_range = range(1, 4) @@ -469,13 +452,13 @@ def test_preliminary_file_map(): ), "Final output file should not exist" # Schedule succeeding jobs with same output paths - futs = executor.map_to_futures( + futs_2 = executor.map_to_futures( square, list(a_range), output_pickle_path_getter=partial(output_pickle_path_getter, tmp_dir), ) - for (fut, job_index) in zip(futs, a_range): - assert fut.result() == square(job_index) + for (fut_2, job_index) in zip(futs_2, a_range): + assert fut_2.result() == square(job_index) for idx in a_range: output_pickle_path = Path(output_pickle_path_getter(tmp_dir, idx)) diff --git a/webknossos/webknossos/dataset/_utils/pims_images.py b/webknossos/webknossos/dataset/_utils/pims_images.py index 533480885..65213fc2f 100644 --- a/webknossos/webknossos/dataset/_utils/pims_images.py +++ b/webknossos/webknossos/dataset/_utils/pims_images.py @@ -389,6 +389,10 @@ def _try_open_bioformats_images_raw( # since it does not include the necessary loci_tools.jar. # Updates to support newer bioformats jars with pims are in PR # https://github.com/soft-matter/pims/pull/403 + + # This is also part of the worker dockerfile to cache the + # jar in the image, please update Dockerfile.worker in the + # voxelytics repo accordingly when editing this. pims.bioformats.download_jar(version="6.7.0") if "*" in str(original_images) or isinstance(original_images, list):
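
Usage sketch: the tests above exercise the refactored executor API, namely `get_executor`, per-job output pickling via `__cfut_options` and `output_pickle_path_getter`, and the string-typed `cluster_jobid` attribute on returned futures. The snippet below is a minimal sketch assembled only from that test usage; the worker function, the resource values, and all paths are illustrative rather than canonical.

    from functools import partial
    from pathlib import Path

    import cluster_tools


    def square(n: float) -> float:
        return n * n


    def output_pickle_path_getter(tmp_dir: str, chunk: int) -> Path:
        # One pickle per work package, mirroring the helper in test_slurm.py.
        return Path(tmp_dir) / f"out_{chunk}.pickle"


    with cluster_tools.get_executor(
        "slurm", debug=True, job_resources={"mem": "10M"}
    ) as executor:
        # Single job: the executor additionally pickles the job's return value
        # to the given path (cf. test_preliminary_file_submit).
        fut = executor.submit(
            square,
            3,
            __cfut_options={"output_pickle_path": "/tmp/out.pickle"},
        )
        assert fut.result() == 9
        # Cluster futures carry the scheduler job id as a string; the tests
        # silence mypy here with `# type: ignore[attr-defined]`.
        print(fut.cluster_jobid)

        # Batch of jobs with one output pickle per work package
        # (cf. test_preliminary_file_map).
        futures = executor.map_to_futures(
            square,
            range(4),
            output_pickle_path_getter=partial(output_pickle_path_getter, "/tmp"),
        )
        results = [f.result() for f in futures]

As the preliminary-file tests above show, a failing job does not produce the final output pickle; the file at the requested path only appears after a successful run, so its presence can be used to detect completed work packages.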