diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 71c6cc1ed9..8698396084 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -5,6 +5,7 @@ import functools import logging import os +import shutil import socket import subprocess import sys @@ -12,7 +13,13 @@ import uuid import zipfile from collections.abc import Awaitable +from contextlib import contextmanager +from importlib.util import find_spec +from io import BytesIO from typing import TYPE_CHECKING, Any, Callable, ClassVar +from types import ModuleType +from typing import Any, Tuple +from pathlib import Path from dask.typing import Key from dask.utils import funcname, tmpfile @@ -29,6 +36,7 @@ from distributed.scheduler import TaskStateState as SchedulerTaskStateState from distributed.worker import Worker from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState + from distributed.node import ServerNode logger = logging.getLogger(__name__) @@ -1051,3 +1059,165 @@ def setup(self, worker): def teardown(self, worker): self._exit_stack.close() + + +@contextmanager +def serialize_module( + module: ModuleType, exclude: Tuple[str] = ("__pycache__", ".DS_Store") +) -> Path: + module_path = Path(module.__file__) + + if module_path.stem == "__init__": + # In case of package we serialize the whole package + module_path = module_path.parent + if "." in module.__name__: + # TODO: the problem is that we serialize the `package.module`, as module.egg that contains module.py, + # but it should contain the whole structure of the package (package/module.py) + raise Exception( + f"Plugin supports only top-level packages or single-file modules. You provided `{module.__name__}`, try `{module.__name__.split('.')[0]}`." + ) + + # In case of single file we don't need to serialize anything + + with tempfile.TemporaryDirectory() as tmp: + package_name = module_path.name + + package_copy_path = Path(tmp).joinpath(package_name) + if module_path.is_dir(): + copied_package = Path( + shutil.copytree( + module_path, + package_copy_path, + ignore=shutil.ignore_patterns(f"{package_name}.zip", *exclude), + ) + ) + else: + copied_package = Path(shutil.copy2(module_path, package_copy_path)) + + archive_path = shutil.make_archive( + # output path including a name w/o extension + base_name=str(copied_package), + format="zip", + # chroot + root_dir=copied_package.parent, + # Name of the directory to archive and a common prefix of all files and directories in the archive + base_dir=package_name, + ) + + egg_file = shutil.move(archive_path, package_copy_path.with_suffix(".egg")) + + # zip file handler + zip = zipfile.ZipFile(egg_file) + # list available files in the container + logger.debug( + "The egg file %s contains the following files %s", + str(egg_file), + str(zip.namelist()), + ) + + logger.info("Created an egg file %s from %s", str(egg_file), str(module_path)) + + yield Path(egg_file) + + +class AbstractUploadModulePlugin: + def __init__(self, module: ModuleType): + self._module_name = module.__name__ + self._data: bytes + self._filepath: Path + self._filename: str + with serialize_module(module) as filepath: + self._filename = filepath.name + with open(filepath, "rb") as f: + self._data = f.read() + + async def _upload_file(self, node: ServerNode): + response = await node.upload_file(self._filename, self._data, load=True) + assert len(self._data) == response["nbytes"] + + async def _upload(self, node: ServerNode): + import zipfile + import sys + try: + from IPython.extensions.autoreload import superreload + except ImportError: + superreload = lambda x: x + + # Try to find already loaded module + module = ( + sys.modules[self._module_name] if self._module_name in sys.modules else None + ) + # Try to find module on disk + module_spec = find_spec(self._module_name) + + if not module_spec and not module: + # If module does not exist we keep it as egg file and load it. + logger.info( + 'Uploading a new module "%s" to "%s" on %s "%s"', + self._module_name, + str(self._filename), + "worker" if isinstance(node, Worker) else "scheduler", + node.id, + ) + await self._upload_file(node) + return + + if module: + module_path = self._get_module_dir(module) + else: + module_path = Path(module_spec.origin) + + if ".egg" in str(module_path): + # Update the previously uploaded egg module and reload it. + logger.info( + 'Uploading an update for a previously uploaded a new module "%s" to "%s" on %s "%s"', + self._module_name, + str(self._filename), + "worker" if isinstance(node, Worker) else "scheduler", + node.id, + ) + await self._upload_file(node) + return + + with zipfile.ZipFile(BytesIO(self._data), "r") as zip_ref: + # In case, we received egg file for module that exists on node in source code, + # we overwrite each file separately by extracting it from the egg. + logger.info( + 'Uploading an update for an existing module "%s" in "%s" on %s "%s"', + self._module_name, + str(module_path.parent), + "worker" if isinstance(node, Worker) else "scheduler", + node.id, + ) + zip_ref.extractall(module_path.parent) + + # TODO: Do we really need Jupyter's `superreload` here instead of built-in Python's function? + if self._module_name in sys.modules: + # Reload module if it is already loaded + superreload(sys.modules[self._module_name]) + + @classmethod + def _get_module_dir(cls, module: ModuleType) -> Path: + """Get the directory of the module.""" + module_path = Path(sys.modules[module.__name__].__file__) + + if module_path.stem == "__init__": + # In case of package we serialize the whole package + return module_path.parent + + # In case of single file we don't need to serialize anything + return module_path + + +class UploadModule(WorkerPlugin, AbstractUploadModulePlugin): + name = "upload_module" + + async def setup(self, worker: Worker): + await self._upload(worker) + + +class SchedulerUploadModule(SchedulerPlugin, AbstractUploadModulePlugin): + name = "upload_module" + + async def start(self, scheduler: Scheduler) -> None: + await self._upload(scheduler) \ No newline at end of file