Skip to content

Commit

Permalink
Merge pull request #26 from lukemartinlogan/master
Browse files Browse the repository at this point in the history
Add real_kwargs. Improve sudo based on Ares
  • Loading branch information
lukemartinlogan authored Sep 22, 2023
2 parents 67884ab + 18bb668 commit 2028e51
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 25 deletions.
4 changes: 4 additions & 0 deletions jarvis_util/introspect/system_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,7 @@ def find_storage(self,
min_cap=None,
min_avail=None,
mount_res=None,
shared=None,
df=None):
"""
Find a set of storage devices.
Expand All @@ -790,6 +791,7 @@ def find_storage(self,
:param min_cap: Remove devices with too little overall capacity
:param min_avail: Remove devices with too little available space
:param mount_res: A regex or list of regexes to match mount points
:param shared: Whether to search for devices which are shared
:param df: The data frame to run this query
:return: Dataframe
"""
Expand Down Expand Up @@ -829,6 +831,8 @@ def find_storage(self,
if common and condense:
df = df.groupby(['mount']).first().reset_index()
# df = df.drop_columns('host')
if shared is not None:
df = df[lambda r: r['shared'] == shared]
return df

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions jarvis_util/jutil_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ def __init__(self):
self.hide_output = False
self.debug_mpi_exec = False
self.debug_local_exec = False
self.debug_scp = False

31 changes: 23 additions & 8 deletions jarvis_util/shell/exec_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from enum import Enum
from jarvis_util.util.hostfile import Hostfile
from jarvis_util.jutil_manager import JutilManager
import os
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -33,7 +34,7 @@ class ExecInfo:
def __init__(self, exec_type=ExecType.LOCAL, nprocs=None, ppn=None,
user=None, pkey=None, port=None,
hostfile=None, hosts=None, env=None,
sleep_ms=0, sudo=False, cwd=None,
sleep_ms=0, sudo=False, sudoenv=True, cwd=None,
collect_output=None, pipe_stdout=None, pipe_stderr=None,
hide_output=None, exec_async=False, stdin=None):
"""
Expand All @@ -49,6 +50,7 @@ def __init__(self, exec_type=ExecType.LOCAL, nprocs=None, ppn=None,
:param env: The environment variables to use for command.
:param sleep_ms: Sleep for a period of time AFTER executing
:param sudo: Execute command with root privilege. E.g., SSH, PSSH
:param sudoenv: Support environment preservation in sudo
:param cwd: Set current working directory. E.g., SSH, PSSH
:param collect_output: Collect program output in python buffer
:param pipe_stdout: Pipe STDOUT into a file. (path string)
Expand All @@ -71,6 +73,7 @@ def __init__(self, exec_type=ExecType.LOCAL, nprocs=None, ppn=None,
self._set_env(env)
self.cwd = cwd
self.sudo = sudo
self.sudoenv = sudoenv
self.sleep_ms = sleep_ms
self.collect_output = collect_output
self.pipe_stdout = pipe_stdout
Expand Down Expand Up @@ -152,6 +155,7 @@ def __init__(self):
self.exit_code = None
self.stdout = ''
self.stderr = ''
self.jutil = JutilManager.get_instance()

def failed(self):
return self.exit_code != 0
Expand All @@ -164,20 +168,31 @@ def set_exit_code(self):
def wait(self):
pass

def smash_cmd(self, cmds):
def smash_cmd(self, cmds, sudo, basic_env, sudoenv):
"""
Convert a list of commands into a single command for the shell
to execute.
:param cmds: A list of commands or a single command string
:param prefix: A prefix for each command
:param sudo: Whether or not root is required
:param basic_env: The environment to forward to the command
:param sudoenv: Whether sudo supports environment forwarding
:return:
"""
if isinstance(cmds, list):
return ' && '.join(cmds)
elif isinstance(cmds, str):
return cmds
else:
raise Exception('Command must be either list or string')
env = None
if sudo:
env = ''
if sudoenv:
env = [f'-E {key}=\"{val}\"' for key, val in
basic_env.items()]
env = ' '.join(env)
env = f'sudo {env}'
if not isinstance(cmds, (list, tuple)):
cmds = [cmds]
if env is not None:
cmds = [f'{env} {cmd}' for cmd in cmds]
return ' && '.join(cmds)

def wait_list(self, nodes):
for node in nodes:
Expand Down
16 changes: 8 additions & 8 deletions jarvis_util/shell/local_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ def __init__(self, cmd, exec_info):
"""

super().__init__()
jutil = JutilManager.get_instance()
cmd = self.smash_cmd(cmd)

# Managing console output and collection
self.collect_output = exec_info.collect_output
Expand All @@ -39,13 +37,13 @@ def __init__(self, cmd, exec_info):
self.hide_output = exec_info.hide_output
# pylint: disable=R1732
if self.collect_output is None:
self.collect_output = jutil.collect_output
self.collect_output = self.jutil.collect_output
if self.pipe_stdout is not None:
self.pipe_stdout_fp = open(self.pipe_stdout, 'wb')
if self.pipe_stderr is not None:
self.pipe_stderr_fp = open(self.pipe_stderr, 'wb')
if self.hide_output is None:
self.hide_output = jutil.hide_output
self.hide_output = self.jutil.hide_output
# pylint: enable=R1732
self.stdout = io.StringIO()
self.stderr = io.StringIO()
Expand All @@ -57,13 +55,13 @@ def __init__(self, cmd, exec_info):
self.exit_code = 0

# Copy ENV
self.basic_env = exec_info.basic_env.copy()
self.env = exec_info.env.copy()
for key, val in os.environ.items():
if key not in self.env:
self.env[key] = val

# Managing command execution
self.cmd = cmd
self.sudo = exec_info.sudo
self.stdin = exec_info.stdin
self.exec_async = exec_info.exec_async
Expand All @@ -72,13 +70,15 @@ def __init__(self, cmd, exec_info):
self.cwd = os.getcwd()
else:
self.cwd = exec_info.cwd
if jutil.debug_local_exec:

# Create the command
cmd = self.smash_cmd(cmd, self.sudo, self.basic_env, exec_info.sudoenv)
self.cmd = cmd
if self.jutil.debug_local_exec:
print(cmd)
self._start_bash_processes()

def _start_bash_processes(self):
if self.sudo:
self.cmd = f'sudo {self.cmd}'
time.sleep(self.sleep_ms)
# pylint: disable=R1732
self.proc = subprocess.Popen(self.cmd,
Expand Down
1 change: 0 additions & 1 deletion jarvis_util/shell/pssh_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(self, cmd, exec_info):
:param exec_info: Info needed to execute command with SSH
"""
super().__init__()
self.cmd = self.smash_cmd(cmd)
self.exec_async = exec_info.exec_async
self.hosts = exec_info.hostfile.hosts
self.execs_ = []
Expand Down
6 changes: 6 additions & 0 deletions jarvis_util/shell/scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
from .local_exec import LocalExec
from .exec_info import Executable
from jarvis_util.jutil_manager import JutilManager


class _Scp(LocalExec):
Expand All @@ -22,12 +23,15 @@ def __init__(self, src_path, dst_path, exec_info):
"""

self.addr = exec_info.hostfile.hosts[0]
if self.addr == 'localhost' or self.addr == '127.0.0.1':
return
self.src_path = src_path
self.dst_path = dst_path
self.user = exec_info.user
self.pkey = exec_info.pkey
self.port = exec_info.port
self.sudo = exec_info.sudo
self.jutil = JutilManager.get_instance()
super().__init__(self.rsync_cmd(src_path, dst_path),
exec_info.mod(env=exec_info.basic_env))

Expand All @@ -47,6 +51,8 @@ def rsync_cmd(self, src_path, dst_path):
else:
lines.append(f'{self.addr}:{dst_path}')
rsync_cmd = ' '.join(lines)
if self.jutil.debug_scp:
print(rsync_cmd)
return rsync_cmd


Expand Down
5 changes: 3 additions & 2 deletions jarvis_util/shell/ssh_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ def __init__(self, cmd, exec_info):
:param exec_info: Info needed to execute command with SSH
"""

cmd = self.smash_cmd(cmd)
self.addr = exec_info.hostfile.hosts[0]
self.user = exec_info.user
self.pkey = exec_info.pkey
self.port = exec_info.port
self.sudo = exec_info.sudo
self.ssh_env = exec_info.env
self.basic_env = exec_info.env
cmd = self.smash_cmd(cmd, self.sudo, self.basic_env, exec_info.sudoenv)
if not exec_info.hostfile.is_local():
super().__init__(self.ssh_cmd(cmd),
exec_info.mod(env=exec_info.basic_env))
Expand All @@ -50,7 +51,7 @@ def ssh_cmd(self, cmd):
cmd_lines.append(f'{key}=\"{val}\"')
cmd_lines.append(cmd)
env_cmd = ' '.join(cmd_lines)
real_cmd = f'{ssh_cmd} \"{env_cmd}\"'
real_cmd = f'{ssh_cmd} {env_cmd}'
return real_cmd


Expand Down
2 changes: 2 additions & 0 deletions jarvis_util/util/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, args=None, exit_on_fail=True, **custom_info):
self.menu = None
self.menu_name = None
self.kwargs = {}
self.real_kwargs = {}
self.define_options()
self._parse()

Expand Down Expand Up @@ -190,6 +191,7 @@ def _parse(self):
default_args = self.default_kwargs(
list(self.menu['kw_opts'].values()) + self.menu['pos_opts'])
default_args.update(self.kwargs)
self.real_kwargs = self.kwargs
self.kwargs = default_args

@staticmethod
Expand Down
9 changes: 3 additions & 6 deletions jarvis_util/util/hostfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,11 @@ def save(self, path):
fp.write('\n'.join(self.all_hosts))
return self

def ip_list(self):
return self.hosts_ip

def hostname_list(self):
return self.hosts
def list(self):
return [Hostfile(all_hosts=[host]) for host in self.hosts]

def enumerate(self):
return enumerate(self.hosts)
return enumerate(self.list())

def host_str(self, sep=','):
return sep.join(self.hosts)
Expand Down

0 comments on commit 2028e51

Please sign in to comment.