Skip to content

Commit

Permalink
Support global setting overrides (#14)
Browse files Browse the repository at this point in the history
* Rework configuration
* Support overriding settings from config file
  • Loading branch information
ethanluoyc authored Oct 31, 2023
1 parent 3e5a079 commit 2da5bd3
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 79 deletions.
134 changes: 105 additions & 29 deletions lxm3/xm_cluster/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import functools
import os
from collections import UserDict
from typing import Any, Dict, Optional

import appdirs
Expand All @@ -12,7 +11,101 @@
)


class Config(UserDict):
class SingularitySettings:
    """Attribute-style accessors over a Singularity settings mapping.

    The wrapped mapping is stored as-is (no copy); every property reads
    from it on each access.
    """

    def __init__(self, data) -> None:
        self._data = data

    def __repr__(self) -> str:
        # Mirror the underlying mapping so debug output shows the raw config.
        return repr(self._data)

    @property
    def command(self) -> str:
        """Value of the ``cmd`` key, or ``"singularity"`` when it is absent."""
        return self._data.get("cmd", "singularity")

    @property
    def binds(self) -> Dict[str, str]:
        """Bind entries flattened into a ``{src: dest}`` dictionary.

        Each entry of the ``binds`` list must provide ``src`` and ``dest``
        keys; a missing ``binds`` key yields an empty dictionary.
        """
        mapping = {}
        for entry in self._data.get("binds", []):
            source = entry["src"]
            mapping[source] = entry["dest"]
        return mapping

    @property
    def env(self) -> Dict[str, str]:
        """Value of the ``env`` key, or an empty dict when it is absent."""
        return self._data.get("env", {})


class LocalSettings:
    """Attribute-style accessors over the local-execution settings mapping.

    The wrapped mapping is stored as-is (no copy); every property reads
    from it on each access.
    """

    def __init__(self, data) -> None:
        self._data = data

    def __repr__(self) -> str:
        # Mirror the underlying mapping so debug output shows the raw config.
        return repr(self._data)

    @property
    def storage_root(self) -> str:
        """Value at ``storage.staging``; both keys are required."""
        storage = self._data["storage"]
        return storage["staging"]

    @property
    def env(self) -> Dict[str, str]:
        """Value of the ``env`` key, or an empty dict when it is absent."""
        return self._data.get("env", {})

    @property
    def singularity(self) -> SingularitySettings:
        """The ``singularity`` sub-mapping (empty if absent), wrapped."""
        section = self._data.get("singularity", {})
        return SingularitySettings(section)


class ClusterSettings:
    """Attribute-style accessors over a single cluster's settings mapping.

    The wrapped mapping is stored as-is (no copy); every property reads
    from it on each access.
    """

    def __init__(self, data) -> None:
        self._data = data

    def __repr__(self) -> str:
        # Mirror the underlying mapping so debug output shows the raw config.
        return repr(self._data)

    @property
    def storage_root(self):
        """Value at ``storage.staging``; both keys are required."""
        storage = self._data["storage"]
        return storage["staging"]

    @property
    def hostname(self):
        """Value of the ``server`` key, or ``None`` when it is absent."""
        return self._data.get("server")

    @property
    def user(self):
        """Value of the ``user`` key, or ``None`` when it is absent."""
        return self._data.get("user")

    @property
    def ssh_config(self):
        """Build a dict of SSH connection keyword arguments.

        Only options present in the settings are emitted:

        * ``proxycommand``    -> ``sock`` (a ``paramiko.ProxyCommand``)
        * ``ssh_private_key`` -> ``key_filename`` (with ``~`` expanded)
        * ``password``        -> ``password``
        """
        lookup = self._data.get
        options = {}

        proxy = lookup("proxycommand")
        if proxy is not None:
            # Imported lazily so paramiko is only needed when a proxy
            # command is actually configured.
            import paramiko

            options["sock"] = paramiko.ProxyCommand(proxy)

        key_path = lookup("ssh_private_key")
        if key_path is not None:
            options["key_filename"] = os.path.expanduser(key_path)

        secret = lookup("password")
        if secret is not None:
            options["password"] = secret

        return options

    @property
    def env(self) -> Dict[str, str]:
        """Value of the ``env`` key, or an empty dict when it is absent."""
        return self._data.get("env", {})

    @property
    def singularity(self) -> SingularitySettings:
        """The ``singularity`` sub-mapping (empty if absent), wrapped."""
        section = self._data.get("singularity", {})
        return SingularitySettings(section)


class Config:
def __init__(self, data) -> None:
self._data = data

def __repr__(self) -> Any:
return repr(self._data)

@classmethod
def from_file(cls, path: str) -> "Config":
with open(path, "rt") as f:
Expand All @@ -28,44 +121,27 @@ def project(self) -> Optional[str]:
project = os.environ.get("LXM_PROJECT", None)
if project is not None:
return project
return self.data.get("project", None)
return self._data.get("project", None)

def local_config(self) -> Dict[str, Any]:
return self.data["local"]
def set_project(self, project):
self._data["project"] = project

def local_settings(self) -> LocalSettings:
return LocalSettings(self._data["local"])

def default_cluster(self) -> str:
cluster = os.environ.get("LXM_CLUSTER", None)
if cluster is None:
cluster = self.data["clusters"][0]["name"]
cluster = self._data["clusters"][0]["name"]
return cluster

def get_cluster_settings(self):
def cluster_settings(self) -> ClusterSettings:
location = self.default_cluster()
clusters = {cluster["name"]: cluster for cluster in self.data["clusters"]}
clusters = {cluster["name"]: cluster for cluster in self._data["clusters"]}
if location not in clusters:
raise ValueError("Unknown cluster")
cluster_config = clusters[location]
storage_root = cluster_config["storage"]["staging"]
hostname = cluster_config.get("server", None)
user = cluster_config.get("user", None)

connect_kwargs = {}

proxycommand = cluster_config.get("proxycommand", None)
if proxycommand is not None:
import paramiko

connect_kwargs["sock"] = paramiko.ProxyCommand(proxycommand)

ssh_private_key = cluster_config.get("ssh_private_key", None)
if ssh_private_key is not None:
connect_kwargs["key_filename"] = os.path.expanduser(ssh_private_key)

password = cluster_config.get("password", None)
if password is not None:
connect_kwargs["password"] = password

return storage_root, hostname, user, connect_kwargs
return ClusterSettings(cluster_config)


@functools.lru_cache()
Expand Down
30 changes: 29 additions & 1 deletion lxm3/xm_cluster/execution/common.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import copy
import os
from typing import List, Optional
from typing import Dict, List, Optional, Union

from lxm3 import xm
from lxm3.xm_cluster import config as config_lib
from lxm3.xm_cluster import executables
from lxm3.xm_cluster import executors
from lxm3.xm_cluster.execution import job_script


def _apply_env_overrides(job: xm.Job, env_overrides: Dict[str, str]):
job.env_vars = copy.copy(job.env_vars)
for k, v in env_overrides.items():
if k not in job.env_vars:
job.env_vars[k] = v


def create_array_job(
*,
executable: executables.Command,
Expand All @@ -19,7 +28,26 @@ def create_array_job(
task_id_var_name: str,
setup: str,
header: str,
settings: Optional[
Union[config_lib.LocalSettings, config_lib.ClusterSettings]
] = None,
):
if settings is not None:
# Apply cluster env overrides
for job in jobs:
_apply_env_overrides(job, settings.env)

if singularity_image is not None and settings is not None:
for job in jobs:
_apply_env_overrides(job, settings.singularity.env)

if singularity_options is None:
singularity_options = executors.SingularityOptions()
else:
singularity_options = copy.deepcopy(singularity_options)
for src, dst in settings.singularity.binds.items():
singularity_options.bind.update({src: dst})

array_wrapper = _create_array_wrapper(executable, jobs, task_offset)
deploy_archive_path = executable.resource_uri
setup_cmds = """
Expand Down
31 changes: 17 additions & 14 deletions lxm3/xm_cluster/execution/gridengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,12 @@ def _get_setup_cmds(
return "\n".join(cmds)


def create_job_script(jobs: List[xm.Job], job_name: str, job_script_dir: str) -> str:
def create_job_script(
cluster_settings: config_lib.ClusterSettings,
jobs: List[xm.Job],
job_name: str,
job_script_dir: str,
) -> str:
executable = jobs[0].executable
executor = jobs[0].executor

Expand All @@ -180,6 +185,7 @@ def create_job_script(jobs: List[xm.Job], job_name: str, job_script_dir: str) ->
task_id_var_name=_TASK_ID_VAR_NAME,
setup=setup,
header=header,
settings=cluster_settings,
)


Expand All @@ -200,31 +206,28 @@ async def launch(
if len(jobs) < 1:
return []

(
storage_root,
hostname,
user,
connect_kwargs,
) = config_lib.default().get_cluster_settings()
cluster_settings = config_lib.default().cluster_settings()

artifact = artifacts.create_artifact_store(
storage_root,
hostname=hostname,
user=user,
cluster_settings.storage_root,
hostname=cluster_settings.hostname,
user=cluster_settings.user,
project=config.project(),
connect_kwargs=connect_kwargs,
connect_kwargs=cluster_settings.ssh_config,
)

version = datetime.datetime.now().strftime("%Y%m%d.%H%M%S")
job_name = f"job-{version}"
job_script_dir = artifact.job_path(job_name)
job_script_content = create_job_script(jobs, job_name, job_script_dir)
job_script_content = create_job_script(
cluster_settings, jobs, job_name, job_script_dir
)

artifact.deploy_job_scripts(job_name, job_script_content)
job_script_path = os.path.join(job_script_dir, job_script.JOB_SCRIPT_NAME)

client = gridengine.Client(hostname, user)
console.log(f"Launching {len(jobs)} jobs on {hostname}")
client = gridengine.Client(cluster_settings.hostname, cluster_settings.user)
console.log(f"Launching {len(jobs)} jobs on {cluster_settings.hostname}")

console.log(f"Launch with command:\n qsub {job_script_path}")
group = client.launch(job_script_path)
Expand Down
8 changes: 5 additions & 3 deletions lxm3/xm_cluster/execution/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _get_setup_cmds(


def create_job_script(
local_settings: config_lib.LocalSettings,
artifact: artifacts.Artifact,
jobs: List[xm.Job],
version: Optional[str] = None,
Expand Down Expand Up @@ -94,6 +95,7 @@ def create_job_script(
task_id_var_name=_TASK_ID_VAR_NAME,
setup=setup,
header=header,
settings=local_settings,
)


Expand All @@ -112,12 +114,12 @@ async def launch(config: config_lib.Config, jobs: List[xm.Job]):
if len(jobs) < 1:
return []

local_config = config.local_config()
local_config = config.local_settings()
artifact = artifacts.LocalArtifact(
local_config["storage"]["staging"], project=config.project()
local_config.storage_root, project=config.project()
)
version = datetime.datetime.now().strftime("%Y%m%d.%H%M%S")
job_script_content = create_job_script(artifact, jobs, version)
job_script_content = create_job_script(local_config, artifact, jobs, version)
job_name = f"job-{version}"

job_script_dir = artifact.job_path(job_name)
Expand Down
31 changes: 18 additions & 13 deletions lxm3/xm_cluster/execution/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ def _get_setup_cmds(
return "\n".join(cmds)


def create_job_script(jobs: List[xm.Job], job_name, job_script_dir) -> str:
def create_job_script(
cluster_settings: config_lib.ClusterSettings,
jobs: List[xm.Job],
job_name,
job_script_dir,
) -> str:
executable = jobs[0].executable
executor = jobs[0].executor

Expand All @@ -143,6 +148,7 @@ def create_job_script(jobs: List[xm.Job], job_name, job_script_dir) -> str:
task_id_var_name=_TASK_ID_VAR_NAME,
setup=setup,
header=header,
settings=cluster_settings,
)


Expand All @@ -166,31 +172,30 @@ async def launch(config: config_lib.Config, jobs: List[xm.Job]) -> List[SlurmHan
"Only GridEngine executors are supported by the gridengine backend."
)

(
storage_root,
hostname,
user,
connect_kwargs,
) = config_lib.default().get_cluster_settings()
cluster_settings = config_lib.default().cluster_settings()

artifact = artifacts.create_artifact_store(
storage_root,
hostname=hostname,
user=user,
cluster_settings.storage_root,
hostname=cluster_settings.hostname,
user=cluster_settings.user,
project=config.project(),
connect_kwargs=connect_kwargs,
connect_kwargs=cluster_settings.ssh_config,
)

version = datetime.datetime.now().strftime("%Y%m%d.%H%M%S")
job_name = f"job-{version}"
job_script_dir = artifact.job_path(job_name)
job_script_content = create_job_script(jobs, job_name, job_script_dir)
job_script_content = create_job_script(
cluster_settings, jobs, job_name, job_script_dir
)

artifact.deploy_job_scripts(job_name, job_script_content)
job_script_path = os.path.join(job_script_dir, job_script.JOB_SCRIPT_NAME)

console.log(f"Launch with command:\n sbatch {job_script_path}")
client = slurm.Client(hostname=hostname, username=user)
client = slurm.Client(
hostname=cluster_settings.hostname, username=cluster_settings.user
)

job_id = client.launch(job_script_path)
common.write_job_id(artifact, job_script_path, str(job_id))
Expand Down
2 changes: 1 addition & 1 deletion lxm3/xm_cluster/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class SingularityOptions(xm.ExecutorSpec):
lxm3 will do it for you.
"""

bind: Optional[Dict[str, str]] = None
bind: Dict[str, str] = attr.Factory(dict)
extra_options: Sequence[str] = attr.Factory(list)


Expand Down
Loading

0 comments on commit 2da5bd3

Please sign in to comment.