From 3eab36a3fce891268c40ec13a4cba1bdee6185e0 Mon Sep 17 00:00:00 2001 From: Bo Peng Date: Thu, 15 Feb 2024 23:11:47 +0000 Subject: [PATCH] Move sos_run_script to functions.py --- setup.py | 7 - src/sos/actions.py | 329 +---------------------------------- src/sos/functions.py | 337 ++++++++++++++++++++++++++++++++++++ src/sos/hosts.py | 3 +- src/sos/parser.py | 2 +- src/sos/runtime.py | 11 +- src/sos/section_analyzer.py | 3 +- src/sos/targets.py | 6 +- test/run_tests.py | 2 +- test/test_parser.py | 7 +- test/test_target.py | 78 ++------- 11 files changed, 374 insertions(+), 411 deletions(-) create mode 100644 src/sos/functions.py diff --git a/setup.py b/setup.py index 21a4fca3c..9aa23f62a 100644 --- a/setup.py +++ b/setup.py @@ -125,14 +125,7 @@ def run(self): [sos_actions] script = sos.actions:script sos_run = sos.actions:sos_run -fail_if = sos.actions:fail_if -warn_if = sos.actions:warn_if -stop_if = sos.actions:stop_if -done_if = sos.actions:done_if -skip_if = sos.actions:skip_if -download = sos.actions:download run = sos.actions:run - bash = sos.actions_bash:bash csh = sos.actions_bash:csh tcsh = sos.actions_bash:tcsh diff --git a/src/sos/actions.py b/src/sos/actions.py index f3c0c0de4..8b25cb345 100644 --- a/src/sos/actions.py +++ b/src/sos/actions.py @@ -4,47 +4,32 @@ # Distributed under the terms of the 3-clause BSD License. 
import copy -import gzip import os import shlex import shutil import subprocess import sys -import tarfile import tempfile import textwrap import time -import urllib -import urllib.error -import urllib.parse -import urllib.request import uuid -import zipfile from collections.abc import Sequence -from concurrent.futures import ProcessPoolExecutor from functools import wraps from typing import Any, Callable, Dict, List, Tuple, Union -from tqdm import tqdm as ProgressBar - from .controller import send_message_to_controller from .eval import interpolate from .messages import decode_msg, encode_msg from .parser import SoS_Script from .syntax import SOS_ACTION_OPTIONS from .targets import executable, file_target, path, paths, sos_targets -from .utils import (StopInputGroup, TerminateExecution, - TimeoutInterProcessLock, env, fileMD5, get_traceback, - load_config_files, short_repr, textMD5, transcribe) +from .utils import (TimeoutInterProcessLock, env, load_config_files, + short_repr, textMD5, transcribe) __all__ = [ "SoS_Action", "script", "sos_run", - "fail_if", - "warn_if", - "stop_if", - "download", "run", "perl", "report", @@ -129,8 +114,7 @@ def action_wrapper(*args, **kwargs): "shub", "oras", ): - env.logger.warning( - f"Container type {cty} might not be supported.") + env.logger.warning(f"Container type {cty} might not be supported.") elif engine is not None and engine != "local": raise ValueError(f"Only docker and singularity container engines are supported: {engine} specified") else: @@ -412,8 +396,9 @@ def run(self, **kwargs): else: raise RuntimeError(f"Unacceptable interpreter {self.interpreter}") - debug_script_path = os.path.dirname(os.path.abspath(kwargs["stderr"])) if ("stderr" in kwargs and kwargs["stderr"] is not False and - os.path.isdir(os.path.dirname(os.path.abspath(kwargs["stderr"])))) else env.exec_dir + debug_script_path = os.path.dirname(os.path.abspath(kwargs["stderr"])) if ( + "stderr" in kwargs and kwargs["stderr"] is not False and + 
os.path.isdir(os.path.dirname(os.path.abspath(kwargs["stderr"])))) else env.exec_dir debug_script_file = os.path.join( debug_script_path, f'{env.sos_dict["step_name"]}_{env.sos_dict["_index"]}_{str(uuid.uuid4())[:8]}{self.suffix}', @@ -769,313 +754,11 @@ def script(script, interpreter="", suffix="", args="", entrypoint="", **kwargs): return SoS_ExecuteScript(script, interpreter, suffix, args, entrypoint).run(**kwargs) -@SoS_Action(acceptable_args=["expr", "msg"]) -def fail_if(expr, msg=""): - """Raise an exception with `msg` if condition `expr` is False""" - if expr: - raise TerminateExecution(msg if msg else "error triggered by action fail_if") - return 0 - - -@SoS_Action(acceptable_args=["expr", "msg"]) -def warn_if(expr, msg=""): - """Yield an warning message `msg` if `expr` is False """ - if expr: - env.logger.warning(msg) - return 0 - - -@SoS_Action(acceptable_args=["expr", "msg", "no_output"]) -def stop_if(expr, msg="", no_output=False): - """Abort the execution of the current step or loop and yield - an warning message `msg` if `expr` is False""" - if expr: - raise StopInputGroup(msg=msg, keep_output=not no_output) - return 0 - - -@SoS_Action(acceptable_args=["expr", "msg"]) -def done_if(expr, msg=""): - """Assuming that output has already been generated and stop - executing the rest of the substep""" - if expr: - raise StopInputGroup(msg=msg, keep_output=True) - return 0 - - -@SoS_Action(acceptable_args=["expr", "msg", "no_output"]) -def skip_if(expr, msg=""): - """Skip the current substep and set _output to empty. 
Output - will be removed if already generated.""" - if expr: - raise StopInputGroup(msg=msg, keep_output=False) - return 0 - - # # download file with progress bar # -def downloadURL(URL, dest, decompress=False, index=None): - dest = os.path.abspath(os.path.expanduser(dest)) - dest_dir, filename = os.path.split(dest) - # - if not os.path.isdir(dest_dir): - os.makedirs(dest_dir, exist_ok=True) - if not os.path.isdir(dest_dir): - raise RuntimeError(f"Failed to create destination directory to download {URL}") - # - message = filename - if len(message) > 30: - message = message[:10] + "..." + message[-16:] - # - dest_tmp = dest + f".tmp_{os.getpid()}" - term_width = shutil.get_terminal_size((80, 20)).columns - try: - env.logger.debug(f"Download {URL} to {dest}") - sig = file_target(dest) - if os.path.isfile(dest): - prog = ProgressBar( - desc=message, - disable=env.verbosity <= 1, - position=index, - leave=True, - bar_format="{desc}", - total=10000000, - ) - target = file_target(dest) - if env.config["sig_mode"] == "build": - prog.set_description(message + ": \033[32m writing signature\033[0m") - prog.update() - target.write_sig() - prog.close() - return True - if env.config["sig_mode"] == "ignore": - prog.set_description(message + ": \033[32m use existing\033[0m") - prog.update() - prog.close() - return True - if env.config["sig_mode"] in ("default", "skip", "distributed"): - prog.update() - if sig.validate(): - prog.set_description(message + ": \033[32m Validated\033[0m") - prog.update() - prog.close() - return True - prog.set_description(message + ":\033[91m Signature mismatch\033[0m") - target.write_sig() - prog.update() - # - prog = ProgressBar( - desc=message, - disable=env.verbosity <= 1, - position=index, - leave=True, - bar_format="{desc}", - total=10000000, - ) - # - # Stop using pycurl because of libcurl version compatibility problems - # that happen so often and difficult to fix. 
Error message looks like - # - # Reason: Incompatible library version: pycurl.cpython-35m-darwin.so - # requires version 9.0.0 or later, but libcurl.4.dylib provides version 7.0.0 - # - # with open(dest_tmp, 'wb') as f: - # c = pycurl.Curl() - # c.setopt(pycurl.URL, str(URL)) - # c.setopt(pycurl.WRITEFUNCTION, f.write) - # c.setopt(pycurl.SSL_VERIFYPEER, False) - # c.setopt(pycurl.NOPROGRESS, False) - # c.setopt(pycurl.PROGRESSFUNCTION, prog.curlUpdate) - # c.perform() - # if c.getinfo(pycurl.HTTP_CODE) == 404: - # prog.set_description(message + ':\033[91m 404 Error {}\033[0m'.format(' '*(term_width - len(message) - 12))) - # try: - # os.remove(dest_tmp) - # except OSError: - # pass - # return False - with open(dest_tmp, "wb") as f: - try: - u = urllib.request.urlopen(str(URL)) - try: - file_size = int(u.getheader("Content-Length")) - prog = ProgressBar(total=file_size, desc=message, position=index, leave=False) - except Exception: - file_size = None - file_size_dl = 0 - block_sz = 8192 - while True: - buffer = u.read(block_sz) - if not buffer: - break - file_size_dl += len(buffer) - f.write(buffer) - prog.update(len(buffer)) - except urllib.error.HTTPError as e: - prog.set_description(message + f":\033[91m {e.code} Error\033[0m") - prog.update() - prog.close() - try: - os.remove(dest_tmp) - except OSError: - pass - return False - except Exception as e: - prog.set_description(message + f":\033[91m {e}\033[0m") - prog.update() - prog.close() - try: - os.remove(dest_tmp) - except OSError: - pass - return False - # - if os.path.isfile(dest): - os.remove(dest) - os.rename(dest_tmp, dest) - decompressed = 0 - if decompress: - if zipfile.is_zipfile(dest): - prog.set_description(message + ":\033[91m Decompressing\033[0m") - prog.update() - prog.close() - zfile = zipfile.ZipFile(dest) - zfile.extractall(dest_dir) - names = zfile.namelist() - for name in names: - if os.path.isdir(os.path.join(dest_dir, name)): - continue - if not os.path.isfile(os.path.join(dest_dir, 
name)): - return False - decompressed += 1 - elif tarfile.is_tarfile(dest): - prog.set_description(message + ":\033[91m Decompressing\033[0m") - prog.update() - prog.close() - with tarfile.open(dest, "r:*") as tar: - tar.extractall(dest_dir) - # only extract files - files = [x.name for x in tar.getmembers() if x.isfile()] - for name in files: - if not os.path.isfile(os.path.join(dest_dir, name)): - return False - decompressed += 1 - elif dest.endswith(".gz"): - prog.set_description(message + ":\033[91m Decompressing\033[0m") - prog.update() - prog.close() - decomp = dest[:-3] - with gzip.open(dest, "rb") as fin, open(decomp, "wb") as fout: - buffer = fin.read(100000) - while buffer: - fout.write(buffer) - buffer = fin.read(100000) - decompressed += 1 - decompress_msg = ("" if not decompressed else - f' ({decompressed} file{"" if decompressed <= 1 else "s"} decompressed)') - prog.set_description( - message + - f':\033[32m downloaded{decompress_msg} {" "*(term_width - len(message) - 13 - len(decompress_msg))}\033[0m') - prog.update() - prog.close() - # if a md5 file exists - # if downloaded files contains .md5 signature, use them to validate - # downloaded files. 
- if os.path.isfile(dest + ".md5"): - prog.set_description(message + ":\033[91m Verifying md5 signature\033[0m") - prog.update() - prog.close() - with open(dest + ".md5") as md5: - rec_md5 = md5.readline().split()[0].strip() - obs_md5 = fileMD5(dest, sig_type='full') - if rec_md5 != obs_md5: - prog.set_description(message + ":\033[91m MD5 signature mismatch\033[0m") - prog.update() - prog.close() - env.logger.warning( - f"md5 signature mismatch for downloaded file {filename[:-4]} (recorded {rec_md5}, observed {obs_md5})" - ) - prog.set_description(message + ":\033[91m MD5 signature verified\033[0m") - prog.update() - prog.close() - except Exception as e: - if env.verbosity > 2: - sys.stderr.write(get_traceback()) - env.logger.error(f"Failed to download: {e}") - return False - finally: - # if there is something wrong still remove temporary file - if os.path.isfile(dest_tmp): - os.remove(dest_tmp) - return os.path.isfile(dest) - - -@SoS_Action(acceptable_args=["URLs", "workdir", "dest_dir", "dest_file", "decompress", "max_jobs"]) -def download(URLs, dest_dir=".", dest_file=None, decompress=False, max_jobs=5): - """Download files from specified URL, which should be space, tab or - newline separated URLs. The files will be downloaded to specified destination. - Option "dest_dir" specify the destination directory, - and "dest_file" specify the output filename, which will otherwise be the same - specified in the URL. If `filename.md5` files are downloaded, they are used to - validate downloaded `filename`. If "decompress=True", compressed - files are decompressed. If `max_jobs` is given, a maximum of `max_jobs` - concurrent download jobs will be used for each domain. This restriction - applies to domain names and will be applied to multiple download - instances. 
- """ - if env.config["run_mode"] == "dryrun": - print(f"HINT: download\n{URLs}\n") - return None - if isinstance(URLs, str): - urls = [x.strip() for x in URLs.split() if x.strip()] - else: - urls = list(URLs) - - if not urls: - env.logger.debug(f"No download URL specified: {URLs}") - return - # - if dest_file is not None and len(urls) != 1: - raise RuntimeError("Only one URL is allowed if a destination file is specified.") - # - if dest_file is None: - filenames = [] - for idx, url in enumerate(urls): - token = urllib.parse.urlparse(url) - # if no scheme or netloc, the URL is not acceptable - if not all([getattr(token, qualifying_attr) for qualifying_attr in ("scheme", "netloc")]): - raise ValueError(f"Invalid URL {url}") - filename = os.path.split(token.path)[-1] - if not filename: - raise ValueError(f"Cannot determine destination file for {url}") - filenames.append(os.path.join(dest_dir, filename)) - else: - token = urllib.parse.urlparse(urls[0]) - if not all([getattr(token, qualifying_attr) for qualifying_attr in ("scheme", "netloc")]): - raise ValueError(f"Invalid URL {url}") - filenames = [dest_file] - # - succ = [(False, None) for x in urls] - with ProcessPoolExecutor(max_workers=max_jobs) as executor: - for idx, (url, filename) in enumerate(zip(urls, filenames)): - # if there is alot, start download - succ[idx] = executor.submit(downloadURL, url, filename, decompress, idx) - succ = [x.result() for x in succ] - - # for su, url in zip(succ, urls): - # if not su: - # env.logger.warning('Failed to download {}'.format(url)) - failed = [y for x, y in zip(succ, urls) if not x] - if failed: - if len(urls) == 1: - raise RuntimeError("Failed to download {urls[0]}") - raise RuntimeError(f"Failed to download {failed[0]} ({len(failed)} out of {len(urls)})") - return 0 - - @SoS_Action(acceptable_args=["script", "interpreter", "args", "entrypoint"]) def run(script, interpreter="", args="", entrypoint="", **kwargs): """Execute specified script using bash. 
This action accepts common action arguments such as
diff --git a/src/sos/functions.py b/src/sos/functions.py
new file mode 100644
index 000000000..3637525d8
--- /dev/null
+++ b/src/sos/functions.py
@@ -0,0 +1,341 @@
+import gzip
+import os
+import shutil
+import sys
+import tarfile
+import urllib
+import urllib.error
+import urllib.parse
+import urllib.request
+import zipfile
+from concurrent.futures import ProcessPoolExecutor
+
+import pkg_resources
+from tqdm import tqdm as ProgressBar
+
+from .targets import file_target
+from .utils import (StopInputGroup, TerminateExecution, env, fileMD5,
+                    get_traceback)
+
+g_action_map = {}
+
+
+def _load_actions():
+    global g_action_map  # pylint: disable=global-variable-not-assigned
+    for _entrypoint in pkg_resources.iter_entry_points(group="sos_actions"):
+        # import actions from entry_points
+        # Grab the function that is the actual plugin.
+        _name = _entrypoint.name
+        try:
+            _plugin = _entrypoint.load()
+            g_action_map[_name] = _plugin
+        except Exception as e:
+            from .utils import get_logger
+
+            # look for sos version requirement
+            get_logger().warning(f"Failed to load script running action {_entrypoint.name}: {e}")
+
+
+def sos_run_script(action, script, *args, **kwargs):
+    if not g_action_map:
+        _load_actions()
+    try:
+        g_action_map[action](script, *args, **kwargs)
+    except KeyError as e:
+        raise RuntimeError(f'Undefined script running action {action}') from e
+
+
+def stop_if(expr, msg="", no_output=False):
+    """Abort the execution of the current step or loop and yield
+    a warning message `msg` if `expr` is True"""
+    if expr:
+        raise StopInputGroup(msg=msg, keep_output=not no_output)
+    return 0
+
+
+def done_if(expr, msg=""):
+    """Assuming that output has already been generated and stop
+    executing the rest of the substep"""
+    if expr:
+        raise StopInputGroup(msg=msg, keep_output=True)
+    return 0
+
+
+def skip_if(expr, msg=""):
+    """Skip the current substep and set _output to empty. 
Output
+    will be removed if already generated."""
+    if expr:
+        raise StopInputGroup(msg=msg, keep_output=False)
+    return 0
+
+
+def fail_if(expr, msg=""):
+    """Raise an exception with `msg` if condition `expr` is True"""
+    if expr:
+        raise TerminateExecution(msg if msg else "error triggered by action fail_if")
+    return 0
+
+
+def warn_if(expr, msg=""):
+    """Yield a warning message `msg` if `expr` is True """
+    if expr:
+        env.logger.warning(msg)
+    return 0
+
+
+def downloadURL(URL, dest, decompress=False, index=None):
+    dest = os.path.abspath(os.path.expanduser(dest))
+    dest_dir, filename = os.path.split(dest)
+    #
+    if not os.path.isdir(dest_dir):
+        os.makedirs(dest_dir, exist_ok=True)
+    if not os.path.isdir(dest_dir):
+        raise RuntimeError(f"Failed to create destination directory to download {URL}")
+    #
+    message = filename
+    if len(message) > 30:
+        message = message[:10] + "..." + message[-16:]
+    #
+    dest_tmp = dest + f".tmp_{os.getpid()}"
+    term_width = shutil.get_terminal_size((80, 20)).columns
+    try:
+        env.logger.debug(f"Download {URL} to {dest}")
+        sig = file_target(dest)
+        if os.path.isfile(dest):
+            prog = ProgressBar(
+                desc=message,
+                disable=env.verbosity <= 1,
+                position=index,
+                leave=True,
+                bar_format="{desc}",
+                total=10000000,
+            )
+            target = file_target(dest)
+            if env.config["sig_mode"] == "build":
+                prog.set_description(message + ": \033[32m writing signature\033[0m")
+                prog.update()
+                target.write_sig()
+                prog.close()
+                return True
+            if env.config["sig_mode"] == "ignore":
+                prog.set_description(message + ": \033[32m use existing\033[0m")
+                prog.update()
+                prog.close()
+                return True
+            if env.config["sig_mode"] in ("default", "skip", "distributed"):
+                prog.update()
+                if sig.validate():
+                    prog.set_description(message + ": \033[32m Validated\033[0m")
+                    prog.update()
+                    prog.close()
+                    return True
+                prog.set_description(message + ":\033[91m Signature mismatch\033[0m")
+                target.write_sig()
+                prog.update()
+        #
+        prog = ProgressBar(
+            desc=message,
+
disable=env.verbosity <= 1, + position=index, + leave=True, + bar_format="{desc}", + total=10000000, + ) + # + # Stop using pycurl because of libcurl version compatibility problems + # that happen so often and difficult to fix. Error message looks like + # + # Reason: Incompatible library version: pycurl.cpython-35m-darwin.so + # requires version 9.0.0 or later, but libcurl.4.dylib provides version 7.0.0 + # + # with open(dest_tmp, 'wb') as f: + # c = pycurl.Curl() + # c.setopt(pycurl.URL, str(URL)) + # c.setopt(pycurl.WRITEFUNCTION, f.write) + # c.setopt(pycurl.SSL_VERIFYPEER, False) + # c.setopt(pycurl.NOPROGRESS, False) + # c.setopt(pycurl.PROGRESSFUNCTION, prog.curlUpdate) + # c.perform() + # if c.getinfo(pycurl.HTTP_CODE) == 404: + # prog.set_description(message + ':\033[91m 404 Error {}\033[0m'.format(' '*(term_width - len(message) - 12))) + # try: + # os.remove(dest_tmp) + # except OSError: + # pass + # return False + with open(dest_tmp, "wb") as f: + try: + u = urllib.request.urlopen(str(URL)) + try: + file_size = int(u.getheader("Content-Length")) + prog = ProgressBar(total=file_size, desc=message, position=index, leave=False) + except Exception: + file_size = None + file_size_dl = 0 + block_sz = 8192 + while True: + buffer = u.read(block_sz) + if not buffer: + break + file_size_dl += len(buffer) + f.write(buffer) + prog.update(len(buffer)) + except urllib.error.HTTPError as e: + prog.set_description(message + f":\033[91m {e.code} Error\033[0m") + prog.update() + prog.close() + try: + os.remove(dest_tmp) + except OSError: + pass + return False + except Exception as e: + prog.set_description(message + f":\033[91m {e}\033[0m") + prog.update() + prog.close() + try: + os.remove(dest_tmp) + except OSError: + pass + return False + # + if os.path.isfile(dest): + os.remove(dest) + os.rename(dest_tmp, dest) + decompressed = 0 + if decompress: + if zipfile.is_zipfile(dest): + prog.set_description(message + ":\033[91m Decompressing\033[0m") + prog.update() + 
prog.close() + zfile = zipfile.ZipFile(dest) + zfile.extractall(dest_dir) + names = zfile.namelist() + for name in names: + if os.path.isdir(os.path.join(dest_dir, name)): + continue + if not os.path.isfile(os.path.join(dest_dir, name)): + return False + decompressed += 1 + elif tarfile.is_tarfile(dest): + prog.set_description(message + ":\033[91m Decompressing\033[0m") + prog.update() + prog.close() + with tarfile.open(dest, "r:*") as tar: + tar.extractall(dest_dir) + # only extract files + files = [x.name for x in tar.getmembers() if x.isfile()] + for name in files: + if not os.path.isfile(os.path.join(dest_dir, name)): + return False + decompressed += 1 + elif dest.endswith(".gz"): + prog.set_description(message + ":\033[91m Decompressing\033[0m") + prog.update() + prog.close() + decomp = dest[:-3] + with gzip.open(dest, "rb") as fin, open(decomp, "wb") as fout: + buffer = fin.read(100000) + while buffer: + fout.write(buffer) + buffer = fin.read(100000) + decompressed += 1 + decompress_msg = ("" if not decompressed else + f' ({decompressed} file{"" if decompressed <= 1 else "s"} decompressed)') + prog.set_description( + message + + f':\033[32m downloaded{decompress_msg} {" "*(term_width - len(message) - 13 - len(decompress_msg))}\033[0m') + prog.update() + prog.close() + # if a md5 file exists + # if downloaded files contains .md5 signature, use them to validate + # downloaded files. 
+ if os.path.isfile(dest + ".md5"): + prog.set_description(message + ":\033[91m Verifying md5 signature\033[0m") + prog.update() + prog.close() + with open(dest + ".md5") as md5: + rec_md5 = md5.readline().split()[0].strip() + obs_md5 = fileMD5(dest, sig_type='full') + if rec_md5 != obs_md5: + prog.set_description(message + ":\033[91m MD5 signature mismatch\033[0m") + prog.update() + prog.close() + env.logger.warning( + f"md5 signature mismatch for downloaded file {filename[:-4]} (recorded {rec_md5}, observed {obs_md5})" + ) + prog.set_description(message + ":\033[91m MD5 signature verified\033[0m") + prog.update() + prog.close() + except Exception as e: + if env.verbosity > 2: + sys.stderr.write(get_traceback()) + env.logger.error(f"Failed to download: {e}") + return False + finally: + # if there is something wrong still remove temporary file + if os.path.isfile(dest_tmp): + os.remove(dest_tmp) + return os.path.isfile(dest) + + +def download(URLs, dest_dir=".", dest_file=None, decompress=False, max_jobs=5): + """Download files from specified URL, which should be space, tab or + newline separated URLs. The files will be downloaded to specified destination. + Option "dest_dir" specify the destination directory, + and "dest_file" specify the output filename, which will otherwise be the same + specified in the URL. If `filename.md5` files are downloaded, they are used to + validate downloaded `filename`. If "decompress=True", compressed + files are decompressed. If `max_jobs` is given, a maximum of `max_jobs` + concurrent download jobs will be used for each domain. This restriction + applies to domain names and will be applied to multiple download + instances. 
+ """ + if env.config["run_mode"] == "dryrun": + print(f"HINT: download\n{URLs}\n") + return None + if isinstance(URLs, str): + urls = [x.strip() for x in URLs.split() if x.strip()] + else: + urls = list(URLs) + + if not urls: + env.logger.debug(f"No download URL specified: {URLs}") + return + # + if dest_file is not None and len(urls) != 1: + raise RuntimeError("Only one URL is allowed if a destination file is specified.") + # + if dest_file is None: + filenames = [] + for idx, url in enumerate(urls): + token = urllib.parse.urlparse(url) + # if no scheme or netloc, the URL is not acceptable + if not all([getattr(token, qualifying_attr) for qualifying_attr in ("scheme", "netloc")]): + raise ValueError(f"Invalid URL {url}") + filename = os.path.split(token.path)[-1] + if not filename: + raise ValueError(f"Cannot determine destination file for {url}") + filenames.append(os.path.join(dest_dir, filename)) + else: + token = urllib.parse.urlparse(urls[0]) + if not all([getattr(token, qualifying_attr) for qualifying_attr in ("scheme", "netloc")]): + raise ValueError(f"Invalid URL {url}") + filenames = [dest_file] + # + succ = [(False, None) for x in urls] + with ProcessPoolExecutor(max_workers=max_jobs) as executor: + for idx, (url, filename) in enumerate(zip(urls, filenames)): + # if there is alot, start download + succ[idx] = executor.submit(downloadURL, url, filename, decompress, idx) + succ = [x.result() for x in succ] + + # for su, url in zip(succ, urls): + # if not su: + # env.logger.warning('Failed to download {}'.format(url)) + failed = [y for x, y in zip(succ, urls) if not x] + if failed: + if len(urls) == 1: + raise RuntimeError("Failed to download {urls[0]}") + raise RuntimeError(f"Failed to download {failed[0]} ({len(failed)} out of {len(urls)})") + return 0 diff --git a/src/sos/hosts.py b/src/sos/hosts.py index ea530957c..028de1e05 100755 --- a/src/sos/hosts.py +++ b/src/sos/hosts.py @@ -23,7 +23,8 @@ from .targets import path, sos_targets from .task_engines 
import BackgroundProcess_TaskEngine from .tasks import TaskFile -from .utils import (env, expand_size, expand_time, format_HHMMSS, short_repr, textMD5) +from .utils import (env, expand_size, expand_time, format_HHMMSS, short_repr, + textMD5) from .workflow_engines import BackgroundProcess_WorkflowEngine # diff --git a/src/sos/parser.py b/src/sos/parser.py index f4706a271..7875b2a9e 100755 --- a/src/sos/parser.py +++ b/src/sos/parser.py @@ -490,7 +490,7 @@ def wrap_script(self) -> None: self._script = repr(self._script) self.statements[-1] = [ "!", - f'{self._action}({self._script}{(", " + opt) if opt else ""})\n', + f'sos_run_script("{self._action}", {self._script}{(", " + opt) if opt else ""})\n', ] self.values = [] self._action = None diff --git a/src/sos/runtime.py b/src/sos/runtime.py index 9646a954c..6acc53e43 100644 --- a/src/sos/runtime.py +++ b/src/sos/runtime.py @@ -4,13 +4,15 @@ # Distributed under the terms of the 3-clause BSD License. import pkg_resources +from .functions import (done_if, download, fail_if, skip_if, sos_run_script, + stop_if, warn_if) # backward compatibility #1337 from .pattern import expand_pattern from .targets import path, paths from .utils import get_output, sos_get_param # silent pyflakes -sos_get_param, get_output, path, paths, expand_pattern +sos_get_param, get_output, path, paths, expand_pattern, done_if, download, fail_if, skip_if, stop_if, warn_if def _load_group(group: str) -> None: @@ -41,13 +43,10 @@ def _load_group(group: str) -> None: continue if _name == "run": # this is critical so we print the warning - get_logger().warning( - f"Failed to load target {_entrypoint.name}: {e}") + get_logger().warning(f"Failed to load target {_entrypoint.name}: {e}") else: - get_logger().debug( - f"Failed to load target {_entrypoint.name}: {e}") + get_logger().debug(f"Failed to load target {_entrypoint.name}: {e}") _load_group("sos_targets") -_load_group("sos_actions") _load_group("sos_functions") diff --git a/src/sos/section_analyzer.py 
b/src/sos/section_analyzer.py index 3e7564536..cdd4c4d4c 100644 --- a/src/sos/section_analyzer.py +++ b/src/sos/section_analyzer.py @@ -11,8 +11,7 @@ from .executor_utils import __null_func__, prepare_env, strip_param_defs from .parser import SoS_Step from .syntax import SOS_TARGETS_OPTIONS -from .targets import (dynamic, file_target, named_output, sos_step, - sos_targets) +from .targets import dynamic, file_target, named_output, sos_step, sos_targets from .utils import env # imported for eval, assert to reduce warning diff --git a/src/sos/targets.py b/src/sos/targets.py index ad918c1af..3c33f3b10 100644 --- a/src/sos/targets.py +++ b/src/sos/targets.py @@ -21,10 +21,12 @@ import fasteners import pkg_resources -from .controller import (request_answer_from_controller, send_message_to_controller) +from .controller import (request_answer_from_controller, + send_message_to_controller) from .eval import get_config, interpolate from .pattern import extract_pattern -from .utils import (Error, env, fileMD5, objectMD5, pickleable, short_repr, stable_repr, textMD5) +from .utils import (Error, env, fileMD5, objectMD5, pickleable, short_repr, + stable_repr, textMD5) __all__ = ["dynamic", "executable", "env_variable", "sos_variable"] diff --git a/test/run_tests.py b/test/run_tests.py index 087407524..8ca992ee4 100755 --- a/test/run_tests.py +++ b/test/run_tests.py @@ -121,7 +121,7 @@ def test_failed(test_names, return_code): failed_tests = retried_failed_tests if failed_tests: - print(f'Failed tests (logged to {LOGFILE}):\n' + '\n'.join(failed_tests)) + print(f'\n\n{len(failed_tests)} failed tests (logged to {LOGFILE}):\n' + '\n'.join(failed_tests)) else: print(f'All {len(all_tests)} tests complete successfully.') sys.exit(0 if not failed_tests else 1) diff --git a/test/test_parser.py b/test/test_parser.py index ba0715f63..9cc284829 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -1473,7 +1473,7 @@ def test_cell(): def test_overwrite_keyword(clear_now_and_after): 
"""Test overwrite sos keyword with user defined one.""" - clear_now_and_after("a.txt") + clear_now_and_after("a.txt", "b.txt") # execute_workflow(""" def run(script): @@ -1483,15 +1483,16 @@ def run(script): run: touch a.txt """) - assert not os.path.isfile("a.txt") + assert os.path.isfile("a.txt") # execute_workflow(""" parameter: run = 5 [1] run: - touch a.txt + touch b.txt """) + assert os.path.isfile("b.txt") def test_help_message(sample_workflow): diff --git a/test/test_target.py b/test/test_target.py index 09a0c4d19..30bc39b76 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -94,12 +94,7 @@ def test_target_set_get(): def test_target_group_by(): """Test new option group_by to sos_targets""" - res = sos_targets( - "e.txt", - "f.ext", - a=["a.txt", "b.txt"], - b=["c.txt", "d.txt"], - group_by=1) + res = sos_targets("e.txt", "f.ext", a=["a.txt", "b.txt"], b=["c.txt", "d.txt"], group_by=1) assert len(res.groups) == 6 assert res.labels == ["", "", "a", "a", "b", "b"] # @@ -110,10 +105,7 @@ def test_target_group_by(): def test_target_paired_with(): """Test paired_with targets with vars""" res = sos_targets( - "e.txt", - "f.ext", - a=["a.txt", "b.txt"], - b=["c.txt", "d.txt"], + "e.txt", "f.ext", a=["a.txt", "b.txt"], b=["c.txt", "d.txt"], group_by=1).paired_with("_name", ["e", "f", "a", "b", "c", "d"]) for i, n in enumerate(["e", "f", "a", "b", "c", "d"]): assert res[i]._name == n @@ -124,17 +116,13 @@ def test_target_paired_with(): # # test assert for length difference with pytest.raises(Exception): - sos_targets("e.txt", - "f.ext").paired_with("name", ["e", "f", "a", "b", "c", "d"]) + sos_targets("e.txt", "f.ext").paired_with("name", ["e", "f", "a", "b", "c", "d"]) def test_target_group_with(): """Test group_with targets with vars""" res = sos_targets( - "e.txt", - "f.ext", - a=["a.txt", "b.txt"], - b=["c.txt", "d.txt"], + "e.txt", "f.ext", a=["a.txt", "b.txt"], b=["c.txt", "d.txt"], group_by=2).group_with("name", ["a1", "a2", "a3"]) for i, n in 
enumerate(["a1", "a2", "a3"]): assert res.groups[i].name == n @@ -168,8 +156,7 @@ def test_group_with_with_no_output(): def test_merging_of_sos_targets(): """Test merging of multiple sos targets""" # merge 0 to 0 - res = sos_targets("a.txt", "b.txt", - sos_targets("c.txt", "d.txt", group_by=1)) + res = sos_targets("a.txt", "b.txt", sos_targets("c.txt", "d.txt", group_by=1)) assert len(res) == 4 assert len(res.groups) == 2 assert res.groups[0] == ["a.txt", "b.txt", "c.txt"] @@ -217,15 +204,12 @@ def test_target_format(): (sos_targets("a b.txt"), "x", ".txt"), ]: if isinstance(res, str): - assert interpolate( - f"{{target:{fmt}}}", globals(), - locals()) == res, "Interpolation of {}:{} should be {}".format( - target, fmt, res) + assert interpolate(f"{{target:{fmt}}}", globals(), + locals()) == res, "Interpolation of {}:{} should be {}".format(target, fmt, res) else: - assert interpolate(f"{{target:{fmt}}}", globals(), locals( - )) in res, "Interpolation of {}:{} should be one of {}".format( - target, fmt, res) + assert interpolate(f"{{target:{fmt}}}", globals(), + locals()) in res, "Interpolation of {}:{} should be one of {}".format(target, fmt, res) def test_iter_targets(): @@ -382,9 +366,7 @@ def test_depends_executable(): file_target("a.txt").unlink() -@pytest.mark.skipif( - sys.platform == "win32", - reason="Windows executable cannot be created with chmod.") +@pytest.mark.skipif(sys.platform == "win32", reason="Windows executable cannot be created with chmod.") def test_output_executable(clear_now_and_after): """Testing target executable.""" # change $PATH so that lls can be found at the current @@ -446,9 +428,7 @@ def test_depends_env_variable(): file_target("a.txt").unlink() -@pytest.mark.skipif( - sys.platform == "win32", - reason="Windows executable cannot be created with chmod.") +@pytest.mark.skipif(sys.platform == "win32", reason="Windows executable cannot be created with chmod.") def test_provides_executable(): """Testing provides executable target.""" # 
change $PATH so that lls can be found at the current @@ -499,7 +479,8 @@ def test_shared_var_in_paired_with(temp_factory): def test_shared_var_in_for_each(temp_factory, clear_now_and_after): temp_factory("1.txt", "2.txt") - clear_now_and_after("1.out", "2.out", "1.out2", "2.out2", '2.out_2.out2', '1.out_1.out2', '1.out_2.out2', '2.out_1.out2') + clear_now_and_after("1.out", "2.out", "1.out2", "2.out2", '2.out_2.out2', '1.out_1.out2', '1.out_2.out2', + '2.out_1.out2') script = SoS_Script(""" [work_1: shared = {'data': 'step_output'}] input: "1.txt", "2.txt", group_by = 'single', pattern = '{name}.{ext}' @@ -687,36 +668,3 @@ def test_temp_file(): assert sos_tempfile(file_target('a.txt')) == sos_tempfile(file_target('a.txt')) """) - - - -def test_named_path(): - """Test the use of option name of path""" - execute_workflow( - """ - import os - # windows might not have HOME - if 'HOME' in os.environ: - assert path('#home') == os.environ['HOME'] - assert 'home' in path.names() - assert 'home' in path.names('docker') - """, - options={ - "config_file": os.path.join(os.path.expanduser("~"), "docker.yml") - }, - ) - - - -@pytest.mark.skipif( - sys.platform == 'win32', reason='Graphviz not available under windows') -def test_to_named_path_path(): - execute_workflow( - """ - [10: shared="a"] - a = path('/root/xxx/whatever').to_named_path(host='docker') - """, - options={ - "config_file": os.path.join(os.path.expanduser("~"), "docker.yml"), - }) - assert env.sos_dict['a'] == '#home/xxx/whatever'