Implement module upload plugin (#8698) #8884

Open · wants to merge 1 commit into base: main

170 changes: 170 additions & 0 deletions distributed/diagnostics/plugin.py
@@ -5,14 +5,21 @@
import functools
import logging
import os
import shutil
import socket
import subprocess
import sys
import tempfile
import uuid
import zipfile
from collections.abc import Awaitable, Iterator
from contextlib import contextmanager
from importlib.util import find_spec
from io import BytesIO
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, ClassVar

from dask.typing import Key
from dask.utils import funcname, tmpfile
@@ -29,6 +36,7 @@
from distributed.scheduler import TaskStateState as SchedulerTaskStateState
from distributed.worker import Worker
from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState
from distributed.node import ServerNode

logger = logging.getLogger(__name__)

@@ -1051,3 +1059,165 @@ def setup(self, worker):

def teardown(self, worker):
self._exit_stack.close()


@contextmanager
def serialize_module(
    module: ModuleType, exclude: tuple[str, ...] = ("__pycache__", ".DS_Store")
) -> Iterator[Path]:
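    """Package ``module`` into a zip archive with an ``.egg`` suffix and yield the archive path."""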
module_path = Path(module.__file__)

    if module_path.stem == "__init__":
        # For a package, archive the whole package directory
        module_path = module_path.parent
        if "." in module.__name__:
            # TODO: a submodule `package.module` is currently serialized as `module.egg`
            # containing `module.py`, but the archive should preserve the package
            # structure (`package/module.py`).
            raise Exception(
                "Plugin supports only top-level packages or single-file modules. "
                f"You provided `{module.__name__}`, try `{module.__name__.split('.')[0]}`."
            )

    # A single-file module needs no special handling; the file itself is archived below.

with tempfile.TemporaryDirectory() as tmp:
package_name = module_path.name

package_copy_path = Path(tmp).joinpath(package_name)
if module_path.is_dir():
copied_package = Path(
shutil.copytree(
module_path,
package_copy_path,
ignore=shutil.ignore_patterns(f"{package_name}.zip", *exclude),
)
)
else:
copied_package = Path(shutil.copy2(module_path, package_copy_path))

        archive_path = shutil.make_archive(
            # Output path, including the archive name without an extension
            base_name=str(copied_package),
            format="zip",
            # Directory to change into before archiving
            root_dir=copied_package.parent,
            # Directory to archive; the common prefix of all entries in the archive
            base_dir=package_name,
        )

egg_file = shutil.move(archive_path, package_copy_path.with_suffix(".egg"))

        # Log the archive contents for debugging
        with zipfile.ZipFile(egg_file) as archive:
            logger.debug(
                "The egg file %s contains the following files: %s",
                egg_file,
                archive.namelist(),
            )

logger.info("Created an egg file %s from %s", str(egg_file), str(module_path))

yield Path(egg_file)


class AbstractUploadModulePlugin:
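    """Shared logic for serializing a local module, uploading it to a node, and (re)loading it there."""
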
def __init__(self, module: ModuleType):
self._module_name = module.__name__
self._data: bytes
self._filename: str
with serialize_module(module) as filepath:
self._filename = filepath.name
with open(filepath, "rb") as f:
self._data = f.read()

async def _upload_file(self, node: ServerNode):
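        """Send the serialized module to ``node`` and verify the number of bytes received."""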
response = await node.upload_file(self._filename, self._data, load=True)
assert len(self._data) == response["nbytes"]

async def _upload(self, node: ServerNode):
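        """Upload the module to ``node``.

        Handles three cases: the module is new to the node, it was previously
        uploaded as an egg, or it already exists on the node as source code.
        """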
        try:
            from IPython.extensions.autoreload import superreload
        except ImportError:
            superreload = lambda x: x

        # The module object, if it is already loaded on this node
        module = sys.modules.get(self._module_name)
        # The module spec, if the module can be found on disk
        module_spec = find_spec(self._module_name)

if not module_spec and not module:
            # The module does not yet exist on the node: upload the egg file and load it.
            logger.info(
                'Uploading a new module "%s" as "%s" to %s "%s"',
self._module_name,
str(self._filename),
"worker" if isinstance(node, Worker) else "scheduler",
node.id,
)
await self._upload_file(node)
return

if module:
module_path = self._get_module_dir(module)
else:
module_path = Path(module_spec.origin)

if ".egg" in str(module_path):
            # Update the previously uploaded egg and reload the module from it.
            logger.info(
                'Uploading an update for the previously uploaded module "%s" as "%s" to %s "%s"',
self._module_name,
str(self._filename),
"worker" if isinstance(node, Worker) else "scheduler",
node.id,
)
await self._upload_file(node)
return

with zipfile.ZipFile(BytesIO(self._data), "r") as zip_ref:
            # The module already exists on the node as source code: overwrite its
            # files in place by extracting them from the egg.
logger.info(
'Uploading an update for an existing module "%s" in "%s" on %s "%s"',
self._module_name,
str(module_path.parent),
"worker" if isinstance(node, Worker) else "scheduler",
node.id,
)
zip_ref.extractall(module_path.parent)

        # TODO: Do we really need IPython's `superreload` here instead of Python's
        # built-in `importlib.reload`?
        if self._module_name in sys.modules:
            # Reload the module if it is already loaded
            superreload(sys.modules[self._module_name])

@classmethod
def _get_module_dir(cls, module: ModuleType) -> Path:
"""Get the directory of the module."""
module_path = Path(sys.modules[module.__name__].__file__)

if module_path.stem == "__init__":
# In case of package we serialize the whole package
return module_path.parent

# In case of single file we don't need to serialize anything
return module_path


class UploadModule(WorkerPlugin, AbstractUploadModulePlugin):
name = "upload_module"

async def setup(self, worker: Worker):
await self._upload(worker)


class SchedulerUploadModule(SchedulerPlugin, AbstractUploadModulePlugin):
name = "upload_module"

async def start(self, scheduler: Scheduler) -> None:
await self._upload(scheduler)
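
A minimal usage sketch (not part of this diff) of how the plugin could be registered from a client; it assumes the `Client.register_plugin` API available in recent distributed releases and a hypothetical local top-level module named `mymodule`:

import mymodule  # hypothetical local module to ship to the cluster

from distributed import Client
from distributed.diagnostics.plugin import SchedulerUploadModule, UploadModule

client = Client()
# Serialize `mymodule` locally and (re)load it on every worker.
client.register_plugin(UploadModule(mymodule))
# Optionally ship it to the scheduler as well.
client.register_plugin(SchedulerUploadModule(mymodule))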