diff --git a/lxm3/xm_cluster/config.py b/lxm3/xm_cluster/config.py index 5f9d51a..6bb7970 100644 --- a/lxm3/xm_cluster/config.py +++ b/lxm3/xm_cluster/config.py @@ -1,6 +1,5 @@ import functools import os -from collections import UserDict from typing import Any, Dict, Optional import appdirs @@ -12,7 +11,101 @@ ) -class Config(UserDict): +class SingularitySettings: + def __init__(self, data) -> None: + self._data = data + + def __repr__(self) -> str: + return repr(self._data) + + @property + def command(self) -> str: + return self._data.get("cmd", "singularity") + + @property + def binds(self) -> Dict[str, str]: + binds = self._data.get("binds", []) + return {bind["src"]: bind["dest"] for bind in binds} + + @property + def env(self) -> Dict[str, str]: + return self._data.get("env", {}) + + +class LocalSettings: + def __init__(self, data) -> None: + self._data = data + + def __repr__(self) -> str: + return repr(self._data) + + @property + def storage_root(self) -> str: + return self._data["storage"]["staging"] + + @property + def env(self) -> Dict[str, str]: + return self._data.get("env", {}) + + @property + def singularity(self) -> SingularitySettings: + return SingularitySettings(self._data.get("singularity", {})) + + +class ClusterSettings: + def __init__(self, data) -> None: + self._data = data + + def __repr__(self) -> str: + return repr(self._data) + + @property + def storage_root(self): + return self._data["storage"]["staging"] + + @property + def hostname(self): + return self._data.get("server", None) + + @property + def user(self): + return self._data.get("user", None) + + @property + def ssh_config(self): + connect_kwargs = {} + proxycommand = self._data.get("proxycommand", None) + if proxycommand is not None: + import paramiko + + connect_kwargs["sock"] = paramiko.ProxyCommand(proxycommand) + + ssh_private_key = self._data.get("ssh_private_key", None) + if ssh_private_key is not None: + connect_kwargs["key_filename"] = os.path.expanduser(ssh_private_key) + + password = self._data.get("password", None) + if password is not None: + connect_kwargs["password"] = password + + return connect_kwargs + + @property + def env(self) -> Dict[str, str]: + return self._data.get("env", {}) + + @property + def singularity(self) -> SingularitySettings: + return SingularitySettings(self._data.get("singularity", {})) + + +class Config: + def __init__(self, data) -> None: + self._data = data + + def __repr__(self) -> Any: + return repr(self._data) + @classmethod def from_file(cls, path: str) -> "Config": with open(path, "rt") as f: @@ -28,44 +121,27 @@ def project(self) -> Optional[str]: project = os.environ.get("LXM_PROJECT", None) if project is not None: return project - return self.data.get("project", None) + return self._data.get("project", None) - def local_config(self) -> Dict[str, Any]: - return self.data["local"] + def set_project(self, project): + self._data["project"] = project + + def local_settings(self) -> LocalSettings: + return LocalSettings(self._data["local"]) def default_cluster(self) -> str: cluster = os.environ.get("LXM_CLUSTER", None) if cluster is None: - cluster = self.data["clusters"][0]["name"] + cluster = self._data["clusters"][0]["name"] return cluster - def get_cluster_settings(self): + def cluster_settings(self) -> ClusterSettings: location = self.default_cluster() - clusters = {cluster["name"]: cluster for cluster in self.data["clusters"]} + clusters = {cluster["name"]: cluster for cluster in self._data["clusters"]} if location not in clusters: raise ValueError("Unknown cluster") cluster_config = clusters[location] - storage_root = cluster_config["storage"]["staging"] - hostname = cluster_config.get("server", None) - user = cluster_config.get("user", None) - - connect_kwargs = {} - - proxycommand = cluster_config.get("proxycommand", None) - if proxycommand is not None: - import paramiko - - connect_kwargs["sock"] = paramiko.ProxyCommand(proxycommand) - - ssh_private_key = cluster_config.get("ssh_private_key", None) - if ssh_private_key is not None: - connect_kwargs["key_filename"] = os.path.expanduser(ssh_private_key) - - password = cluster_config.get("password", None) - if password is not None: - connect_kwargs["password"] = password - - return storage_root, hostname, user, connect_kwargs + return ClusterSettings(cluster_config) @functools.lru_cache() diff --git a/lxm3/xm_cluster/execution/common.py b/lxm3/xm_cluster/execution/common.py index eb12926..ffb293c 100644 --- a/lxm3/xm_cluster/execution/common.py +++ b/lxm3/xm_cluster/execution/common.py @@ -1,12 +1,21 @@ +import copy import os -from typing import List, Optional +from typing import Dict, List, Optional, Union from lxm3 import xm +from lxm3.xm_cluster import config as config_lib from lxm3.xm_cluster import executables from lxm3.xm_cluster import executors from lxm3.xm_cluster.execution import job_script +def _apply_env_overrides(job: xm.Job, env_overrides: Dict[str, str]): + job.env_vars = copy.copy(job.env_vars) + for k, v in env_overrides.items(): + if k not in job.env_vars: + job.env_vars[k] = v + + def create_array_job( *, executable: executables.Command, @@ -19,7 +28,26 @@ def create_array_job( task_id_var_name: str, setup: str, header: str, + settings: Optional[ + Union[config_lib.LocalSettings, config_lib.ClusterSettings] + ] = None, ): + if settings is not None: + # Apply cluster env overrides + for job in jobs: + _apply_env_overrides(job, settings.env) + + if singularity_image is not None and settings is not None: + for job in jobs: + _apply_env_overrides(job, settings.singularity.env) + + if singularity_options is None: + singularity_options = executors.SingularityOptions() + else: + singularity_options = copy.deepcopy(singularity_options) + for src, dst in settings.singularity.binds.items(): + singularity_options.bind.update({src: dst}) + array_wrapper = _create_array_wrapper(executable, jobs, task_offset) deploy_archive_path = executable.resource_uri setup_cmds = """ diff --git a/lxm3/xm_cluster/execution/gridengine.py b/lxm3/xm_cluster/execution/gridengine.py index 7858384..0443016 100644 --- a/lxm3/xm_cluster/execution/gridengine.py +++ b/lxm3/xm_cluster/execution/gridengine.py @@ -158,7 +158,12 @@ def _get_setup_cmds( return "\n".join(cmds) -def create_job_script(jobs: List[xm.Job], job_name: str, job_script_dir: str) -> str: +def create_job_script( + cluster_settings: config_lib.ClusterSettings, + jobs: List[xm.Job], + job_name: str, + job_script_dir: str, +) -> str: executable = jobs[0].executable executor = jobs[0].executor @@ -180,6 +185,7 @@ def create_job_script(jobs: List[xm.Job], job_name: str, job_script_dir: str) -> task_id_var_name=_TASK_ID_VAR_NAME, setup=setup, header=header, + settings=cluster_settings, ) @@ -200,31 +206,28 @@ async def launch( if len(jobs) < 1: return [] - ( - storage_root, - hostname, - user, - connect_kwargs, - ) = config_lib.default().get_cluster_settings() + cluster_settings = config_lib.default().cluster_settings() artifact = artifacts.create_artifact_store( - storage_root, - hostname=hostname, - user=user, + cluster_settings.storage_root, + hostname=cluster_settings.hostname, + user=cluster_settings.user, project=config.project(), - connect_kwargs=connect_kwargs, + connect_kwargs=cluster_settings.ssh_config, ) version = datetime.datetime.now().strftime("%Y%m%d.%H%M%S") job_name = f"job-{version}" job_script_dir = artifact.job_path(job_name) - job_script_content = create_job_script(jobs, job_name, job_script_dir) + job_script_content = create_job_script( + cluster_settings, jobs, job_name, job_script_dir + ) artifact.deploy_job_scripts(job_name, job_script_content) job_script_path = os.path.join(job_script_dir, job_script.JOB_SCRIPT_NAME) - client = gridengine.Client(hostname, user) - console.log(f"Launching {len(jobs)} jobs on {hostname}") + client = gridengine.Client(cluster_settings.hostname, cluster_settings.user) + console.log(f"Launching {len(jobs)} jobs on {cluster_settings.hostname}") console.log(f"Launch with command:\n qsub {job_script_path}") group = client.launch(job_script_path) diff --git a/lxm3/xm_cluster/execution/local.py b/lxm3/xm_cluster/execution/local.py index 98947f7..3ccc292 100644 --- a/lxm3/xm_cluster/execution/local.py +++ b/lxm3/xm_cluster/execution/local.py @@ -64,6 +64,7 @@ def _get_setup_cmds( def create_job_script( + local_settings: config_lib.LocalSettings, artifact: artifacts.Artifact, jobs: List[xm.Job], version: Optional[str] = None, @@ -94,6 +95,7 @@ def create_job_script( task_id_var_name=_TASK_ID_VAR_NAME, setup=setup, header=header, + settings=local_settings, ) @@ -112,12 +114,12 @@ async def launch(config: config_lib.Config, jobs: List[xm.Job]): if len(jobs) < 1: return [] - local_config = config.local_config() + local_config = config.local_settings() artifact = artifacts.LocalArtifact( - local_config["storage"]["staging"], project=config.project() + local_config.storage_root, project=config.project() ) version = datetime.datetime.now().strftime("%Y%m%d.%H%M%S") - job_script_content = create_job_script(artifact, jobs, version) + job_script_content = create_job_script(local_config, artifact, jobs, version) job_name = f"job-{version}" job_script_dir = artifact.job_path(job_name) diff --git a/lxm3/xm_cluster/execution/slurm.py b/lxm3/xm_cluster/execution/slurm.py index 2673402..d28893b 100644 --- a/lxm3/xm_cluster/execution/slurm.py +++ b/lxm3/xm_cluster/execution/slurm.py @@ -121,7 +121,12 @@ def _get_setup_cmds( return "\n".join(cmds) -def create_job_script(jobs: List[xm.Job], job_name, job_script_dir) -> str: +def create_job_script( + cluster_settings: config_lib.ClusterSettings, + jobs: List[xm.Job], + job_name, + job_script_dir, +) -> str: executable = jobs[0].executable executor = jobs[0].executor @@ -143,6 +148,7 @@ def create_job_script(jobs: List[xm.Job], job_name, job_script_dir) -> str: task_id_var_name=_TASK_ID_VAR_NAME, setup=setup, header=header, + settings=cluster_settings, ) @@ -166,31 +172,30 @@ async def launch(config: config_lib.Config, jobs: List[xm.Job]) -> List[SlurmHan "Only GridEngine executors are supported by the gridengine backend." ) - ( - storage_root, - hostname, - user, - connect_kwargs, - ) = config_lib.default().get_cluster_settings() + cluster_settings = config_lib.default().cluster_settings() artifact = artifacts.create_artifact_store( - storage_root, - hostname=hostname, - user=user, + cluster_settings.storage_root, + hostname=cluster_settings.hostname, + user=cluster_settings.user, project=config.project(), - connect_kwargs=connect_kwargs, + connect_kwargs=cluster_settings.ssh_config, ) version = datetime.datetime.now().strftime("%Y%m%d.%H%M%S") job_name = f"job-{version}" job_script_dir = artifact.job_path(job_name) - job_script_content = create_job_script(jobs, job_name, job_script_dir) + job_script_content = create_job_script( + cluster_settings, jobs, job_name, job_script_dir + ) artifact.deploy_job_scripts(job_name, job_script_content) job_script_path = os.path.join(job_script_dir, job_script.JOB_SCRIPT_NAME) console.log(f"Launch with command:\n sbatch {job_script_path}") - client = slurm.Client(hostname=hostname, username=user) + client = slurm.Client( + hostname=cluster_settings.hostname, username=cluster_settings.user + ) job_id = client.launch(job_script_path) common.write_job_id(artifact, job_script_path, str(job_id)) diff --git a/lxm3/xm_cluster/executors.py b/lxm3/xm_cluster/executors.py index 405c7e0..13028e9 100644 --- a/lxm3/xm_cluster/executors.py +++ b/lxm3/xm_cluster/executors.py @@ -38,7 +38,7 @@ class SingularityOptions(xm.ExecutorSpec): lxm3 will do it for you. """ - bind: Optional[Dict[str, str]] = None + bind: Dict[str, str] = attr.Factory(dict) extra_options: Sequence[str] = attr.Factory(list) diff --git a/lxm3/xm_cluster/experiment.py b/lxm3/xm_cluster/experiment.py index 47f3c94..974e704 100644 --- a/lxm3/xm_cluster/experiment.py +++ b/lxm3/xm_cluster/experiment.py @@ -164,7 +164,7 @@ class ClusterExperiment(xm.Experiment): def __init__( self, experiment_title: str, - config: Mapping[str, Any], + config: config_lib.Config, vcs: Optional[vcsinfo.VCS] = None, ) -> None: super().__init__() @@ -329,6 +329,6 @@ def create_experiment( vcs = _load_vcsinfo() if not config.project() and vcs is not None: - config.data["project"] = vcs.name + config.set_project(vcs.name) return ClusterExperiment(experiment_title, config=config) diff --git a/lxm3/xm_cluster/packaging/__init__.py b/lxm3/xm_cluster/packaging/__init__.py index 44654a4..307fb83 100644 --- a/lxm3/xm_cluster/packaging/__init__.py +++ b/lxm3/xm_cluster/packaging/__init__.py @@ -290,9 +290,9 @@ def _package_for_local_executor( del executor_spec config = config_lib.default() - local_config = config_lib.default().local_config() + local_settings = config_lib.default().local_settings() artifact = artifacts.LocalArtifact( - local_config["storage"]["staging"], project=config.project() + local_settings.storage_root, project=config.project() ) return _PACKAGING_ROUTER(packageable.executable_spec, packageable, artifact) @@ -302,14 +302,14 @@ def _package_for_gridengine_executor( ): del executor_spec config = config_lib.default() - storage_root, hostname, user, connect_kwargs = config.get_cluster_settings() + cluster_settings = config.cluster_settings() artifact = artifacts.create_artifact_store( - storage_root, - hostname=hostname, - user=user, + cluster_settings.storage_root, + hostname=cluster_settings.hostname, + user=cluster_settings.user, project=config.project(), - connect_kwargs=connect_kwargs, + connect_kwargs=cluster_settings.ssh_config, ) return _PACKAGING_ROUTER(packageable.executable_spec, packageable, artifact) @@ -319,14 +319,14 @@ def _package_for_slurm_executor( ): del executor_spec config = config_lib.default() - storage_root, hostname, user, connect_kwargs = config.get_cluster_settings() + cluster_settings = config.cluster_settings() artifact = artifacts.create_artifact_store( - storage_root, - hostname=hostname, - user=user, + cluster_settings.storage_root, + hostname=cluster_settings.hostname, + user=cluster_settings.user, project=config.project(), - connect_kwargs=connect_kwargs, + connect_kwargs=cluster_settings.ssh_config, ) return _PACKAGING_ROUTER(packageable.executable_spec, packageable, artifact) diff --git a/tests/config_test.py b/tests/config_test.py index d68a172..d0edf7e 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -38,12 +38,12 @@ def _test_config(): class ConfigTest(parameterized.TestCase): def test_config(self): config = _test_config() - self.assertTrue(isinstance(config["clusters"], list)) - self.assertEqual(config["clusters"][0]["name"], "cs") + self.assertTrue(isinstance(config._data["clusters"], list)) + self.assertEqual(config._data["clusters"][0]["name"], "cs") def test_local_config(self): config = _test_config() - self.assertEqual(config.local_config()["storage"], {"staging": ".lxm"}) + self.assertEqual(config.local_settings().storage_root, ".lxm") def test_default_cluster(self): config = _test_config() @@ -52,7 +52,7 @@ def test_default_cluster(self): self.assertEqual(config.default_cluster(), "myriad") def test_config_project(self): - config = config_lib.Config() + config = config_lib.Config.from_string("") self.assertEqual(config.project(), None) with unittest.mock.patch.dict("os.environ", {"LXM_PROJECT": "test"}): self.assertEqual(config.project(), "test")