diff --git a/jarvis_util/introspect/system_info.py b/jarvis_util/introspect/system_info.py index bf2c15a..c26fe46 100644 --- a/jarvis_util/introspect/system_info.py +++ b/jarvis_util/introspect/system_info.py @@ -775,6 +775,7 @@ def find_storage(self, min_cap=None, min_avail=None, mount_res=None, + shared=None, df=None): """ Find a set of storage devices. @@ -790,6 +791,7 @@ def find_storage(self, :param min_cap: Remove devices with too little overall capacity :param min_avail: Remove devices with too little available space :param mount_res: A regex or list of regexes to match mount points + :param shared: Whether to search for devices which are shared :param df: The data frame to run this query :return: Dataframe """ @@ -829,6 +831,8 @@ def find_storage(self, if common and condense: df = df.groupby(['mount']).first().reset_index() # df = df.drop_columns('host') + if shared is not None: + df = df[lambda r: r['shared'] == shared] return df @staticmethod diff --git a/jarvis_util/jutil_manager.py b/jarvis_util/jutil_manager.py index c229b24..abeb214 100644 --- a/jarvis_util/jutil_manager.py +++ b/jarvis_util/jutil_manager.py @@ -25,4 +25,5 @@ def __init__(self): self.hide_output = False self.debug_mpi_exec = False self.debug_local_exec = False + self.debug_scp = False diff --git a/jarvis_util/shell/exec_info.py b/jarvis_util/shell/exec_info.py index a415bc1..94108fd 100644 --- a/jarvis_util/shell/exec_info.py +++ b/jarvis_util/shell/exec_info.py @@ -6,6 +6,7 @@ from enum import Enum from jarvis_util.util.hostfile import Hostfile +from jarvis_util.jutil_manager import JutilManager import os from abc import ABC, abstractmethod @@ -33,7 +34,7 @@ class ExecInfo: def __init__(self, exec_type=ExecType.LOCAL, nprocs=None, ppn=None, user=None, pkey=None, port=None, hostfile=None, hosts=None, env=None, - sleep_ms=0, sudo=False, cwd=None, + sleep_ms=0, sudo=False, sudoenv=True, cwd=None, collect_output=None, pipe_stdout=None, pipe_stderr=None, hide_output=None, exec_async=False, stdin=None): """ @@ -49,6 +50,7 @@ def __init__(self, exec_type=ExecType.LOCAL, nprocs=None, ppn=None, :param env: The environment variables to use for command. :param sleep_ms: Sleep for a period of time AFTER executing :param sudo: Execute command with root privilege. E.g., SSH, PSSH + :param sudoenv: Support environment preservation in sudo :param cwd: Set current working directory. E.g., SSH, PSSH :param collect_output: Collect program output in python buffer :param pipe_stdout: Pipe STDOUT into a file. (path string) @@ -71,6 +73,7 @@ def __init__(self, exec_type=ExecType.LOCAL, nprocs=None, ppn=None, self._set_env(env) self.cwd = cwd self.sudo = sudo + self.sudoenv = sudoenv self.sleep_ms = sleep_ms self.collect_output = collect_output self.pipe_stdout = pipe_stdout @@ -152,6 +155,7 @@ def __init__(self): self.exit_code = None self.stdout = '' self.stderr = '' + self.jutil = JutilManager.get_instance() def failed(self): return self.exit_code != 0 @@ -164,20 +168,31 @@ def set_exit_code(self): def wait(self): pass - def smash_cmd(self, cmds): + def smash_cmd(self, cmds, sudo, basic_env, sudoenv): """ Convert a list of commands into a single command for the shell to execute. :param cmds: A list of commands or a single command string + :param prefix: A prefix for each command + :param sudo: Whether or not root is required + :param basic_env: The environment to forward to the command + :param sudoenv: Whether sudo supports environment forwarding :return: """ - if isinstance(cmds, list): - return ' && '.join(cmds) - elif isinstance(cmds, str): - return cmds - else: - raise Exception('Command must be either list or string') + env = None + if sudo: + env = '' + if sudoenv: + env = [f'-E {key}=\"{val}\"' for key, val in + basic_env.items()] + env = ' '.join(env) + env = f'sudo {env}' + if not isinstance(cmds, (list, tuple)): + cmds = [cmds] + if env is not None: + cmds = [f'{env} {cmd}' for cmd in cmds] + return ' && '.join(cmds) def wait_list(self, nodes): for node in nodes: diff --git a/jarvis_util/shell/local_exec.py b/jarvis_util/shell/local_exec.py index c9eac96..cade8d5 100644 --- a/jarvis_util/shell/local_exec.py +++ b/jarvis_util/shell/local_exec.py @@ -27,8 +27,6 @@ def __init__(self, cmd, exec_info): """ super().__init__() - jutil = JutilManager.get_instance() - cmd = self.smash_cmd(cmd) # Managing console output and collection self.collect_output = exec_info.collect_output @@ -39,13 +37,13 @@ def __init__(self, cmd, exec_info): self.hide_output = exec_info.hide_output # pylint: disable=R1732 if self.collect_output is None: - self.collect_output = jutil.collect_output + self.collect_output = self.jutil.collect_output if self.pipe_stdout is not None: self.pipe_stdout_fp = open(self.pipe_stdout, 'wb') if self.pipe_stderr is not None: self.pipe_stderr_fp = open(self.pipe_stderr, 'wb') if self.hide_output is None: - self.hide_output = jutil.hide_output + self.hide_output = self.jutil.hide_output # pylint: enable=R1732 self.stdout = io.StringIO() self.stderr = io.StringIO() @@ -57,13 +55,13 @@ def __init__(self, cmd, exec_info): self.exit_code = 0 # Copy ENV + self.basic_env = exec_info.basic_env.copy() self.env = exec_info.env.copy() for key, val in os.environ.items(): if key not in self.env: self.env[key] = val # Managing command execution - self.cmd = cmd self.sudo = exec_info.sudo self.stdin = exec_info.stdin self.exec_async = exec_info.exec_async @@ -72,13 +70,15 @@ def __init__(self, cmd, exec_info): self.cwd = os.getcwd() else: self.cwd = exec_info.cwd - if jutil.debug_local_exec: + + # Create the command + cmd = self.smash_cmd(cmd, self.sudo, self.basic_env, exec_info.sudoenv) + self.cmd = cmd + if self.jutil.debug_local_exec: print(cmd) self._start_bash_processes() def _start_bash_processes(self): - if self.sudo: - self.cmd = f'sudo {self.cmd}' time.sleep(self.sleep_ms) # pylint: disable=R1732 self.proc = subprocess.Popen(self.cmd, diff --git a/jarvis_util/shell/pssh_exec.py b/jarvis_util/shell/pssh_exec.py index f2c97ca..7aa53b8 100644 --- a/jarvis_util/shell/pssh_exec.py +++ b/jarvis_util/shell/pssh_exec.py @@ -22,7 +22,6 @@ def __init__(self, cmd, exec_info): :param exec_info: Info needed to execute command with SSH """ super().__init__() - self.cmd = self.smash_cmd(cmd) self.exec_async = exec_info.exec_async self.hosts = exec_info.hostfile.hosts self.execs_ = [] diff --git a/jarvis_util/shell/scp.py b/jarvis_util/shell/scp.py index 8b39301..5e1df07 100644 --- a/jarvis_util/shell/scp.py +++ b/jarvis_util/shell/scp.py @@ -4,6 +4,7 @@ """ from .local_exec import LocalExec from .exec_info import Executable +from jarvis_util.jutil_manager import JutilManager class _Scp(LocalExec): @@ -22,12 +23,15 @@ def __init__(self, src_path, dst_path, exec_info): """ self.addr = exec_info.hostfile.hosts[0] + if self.addr == 'localhost' or self.addr == '127.0.0.1': + return self.src_path = src_path self.dst_path = dst_path self.user = exec_info.user self.pkey = exec_info.pkey self.port = exec_info.port self.sudo = exec_info.sudo + self.jutil = JutilManager.get_instance() super().__init__(self.rsync_cmd(src_path, dst_path), exec_info.mod(env=exec_info.basic_env)) @@ -47,6 +51,8 @@ def rsync_cmd(self, src_path, dst_path): else: lines.append(f'{self.addr}:{dst_path}') rsync_cmd = ' '.join(lines) + if self.jutil.debug_scp: + print(rsync_cmd) return rsync_cmd diff --git a/jarvis_util/shell/ssh_exec.py b/jarvis_util/shell/ssh_exec.py index 3b95447..e207971 100644 --- a/jarvis_util/shell/ssh_exec.py +++ b/jarvis_util/shell/ssh_exec.py @@ -19,13 +19,14 @@ def __init__(self, cmd, exec_info): :param exec_info: Info needed to execute command with SSH """ - cmd = self.smash_cmd(cmd) self.addr = exec_info.hostfile.hosts[0] self.user = exec_info.user self.pkey = exec_info.pkey self.port = exec_info.port self.sudo = exec_info.sudo self.ssh_env = exec_info.env + self.basic_env = exec_info.env + cmd = self.smash_cmd(cmd, self.sudo, self.basic_env, exec_info.sudoenv) if not exec_info.hostfile.is_local(): super().__init__(self.ssh_cmd(cmd), exec_info.mod(env=exec_info.basic_env)) @@ -50,7 +51,7 @@ def ssh_cmd(self, cmd): cmd_lines.append(f'{key}=\"{val}\"') cmd_lines.append(cmd) env_cmd = ' '.join(cmd_lines) - real_cmd = f'{ssh_cmd} \"{env_cmd}\"' + real_cmd = f'{ssh_cmd} {env_cmd}' return real_cmd diff --git a/jarvis_util/util/argparse.py b/jarvis_util/util/argparse.py index daa6566..2cc7dac 100644 --- a/jarvis_util/util/argparse.py +++ b/jarvis_util/util/argparse.py @@ -45,6 +45,7 @@ def __init__(self, args=None, exit_on_fail=True, **custom_info): self.menu = None self.menu_name = None self.kwargs = {} + self.real_kwargs = {} self.define_options() self._parse() @@ -190,6 +191,7 @@ def _parse(self): default_args = self.default_kwargs( list(self.menu['kw_opts'].values()) + self.menu['pos_opts']) default_args.update(self.kwargs) + self.real_kwargs = self.kwargs self.kwargs = default_args @staticmethod diff --git a/jarvis_util/util/hostfile.py b/jarvis_util/util/hostfile.py index 124d123..5d015de 100644 --- a/jarvis_util/util/hostfile.py +++ b/jarvis_util/util/hostfile.py @@ -193,14 +193,11 @@ def save(self, path): fp.write('\n'.join(self.all_hosts)) return self - def ip_list(self): - return self.hosts_ip - - def hostname_list(self): - return self.hosts + def list(self): + return [Hostfile(all_hosts=[host]) for host in self.hosts] def enumerate(self): - return enumerate(self.hosts) + return enumerate(self.list()) def host_str(self, sep=','): return sep.join(self.hosts)