From 33bf5968b73d1c674de2428baaf2ff360a37c8e9 Mon Sep 17 00:00:00 2001 From: John Chilton Date: Thu, 18 Jul 2024 11:59:31 -0400 Subject: [PATCH] More typing of test case parsing. --- lib/galaxy/tool_util/parser/interface.py | 3 +- lib/galaxy/tool_util/verify/_types.py | 14 +++++++ lib/galaxy/tool_util/verify/interactor.py | 49 +++++++++++++++-------- lib/galaxy/tool_util/verify/parse.py | 43 ++++++++++++-------- 4 files changed, 76 insertions(+), 33 deletions(-) create mode 100644 lib/galaxy/tool_util/verify/_types.py diff --git a/lib/galaxy/tool_util/parser/interface.py b/lib/galaxy/tool_util/parser/interface.py index 809fe5f0c387..092938836a6c 100644 --- a/lib/galaxy/tool_util/parser/interface.py +++ b/lib/galaxy/tool_util/parser/interface.py @@ -45,12 +45,13 @@ class AssertionDict(TypedDict): ToolSourceTestInputs = Any ToolSourceTestOutputs = Any +TestSourceTestOutputColllection = Any class ToolSourceTest(TypedDict): inputs: ToolSourceTestInputs outputs: ToolSourceTestOutputs - output_collections: List[Any] + output_collections: List[TestSourceTestOutputColllection] stdout: AssertionList stderr: AssertionList expect_exit_code: Optional[XmlInt] diff --git a/lib/galaxy/tool_util/verify/_types.py b/lib/galaxy/tool_util/verify/_types.py new file mode 100644 index 000000000000..e7362fafd7b8 --- /dev/null +++ b/lib/galaxy/tool_util/verify/_types.py @@ -0,0 +1,14 @@ +"""Types used by interactor and test case processor.""" + +from typing import ( + Any, + Dict, + List, + Tuple, +) + +ExtraFileInfoDictT = Dict[str, Any] +RequiredFileTuple = Tuple[str, ExtraFileInfoDictT] +RequiredFilesT = List[RequiredFileTuple] +RequiredDataTablesT = List[str] +RequiredLocFileT = List[str] diff --git a/lib/galaxy/tool_util/verify/interactor.py b/lib/galaxy/tool_util/verify/interactor.py index c31dba70175e..52b0218cedfb 100644 --- a/lib/galaxy/tool_util/verify/interactor.py +++ b/lib/galaxy/tool_util/verify/interactor.py @@ -39,6 +39,7 @@ AssertionList, TestCollectionDef, TestCollectionOutputDef, + TestSourceTestOutputColllection, ) from galaxy.util import requests from galaxy.util.bunch import Bunch @@ -47,6 +48,11 @@ parse_checksum_hash, ) from . import verify +from ._types import ( + RequiredDataTablesT, + RequiredFilesT, + RequiredLocFileT, +) from .asserts import verify_assertions from .wait import wait_on @@ -92,7 +98,7 @@ def __getitem__(self, item): class ValidToolTestDict(TypedDict): inputs: Any outputs: Any - output_collections: List[Dict[str, Any]] + output_collections: List[TestSourceTestOutputColllection] stdout: NotRequired[AssertionList] stderr: NotRequired[AssertionList] expect_exit_code: NotRequired[Optional[Union[str, int]]] @@ -102,9 +108,9 @@ class ValidToolTestDict(TypedDict): num_outputs: NotRequired[Optional[Union[str, int]]] command_line: NotRequired[AssertionList] command_version: NotRequired[AssertionList] - required_files: NotRequired[List[Any]] - required_data_tables: NotRequired[List[Any]] - required_loc_files: NotRequired[List[str]] + required_files: NotRequired[RequiredFilesT] + required_data_tables: NotRequired[RequiredDataTablesT] + required_loc_files: NotRequired[RequiredLocFileT] error: Literal[False] tool_id: str tool_version: str @@ -1661,7 +1667,7 @@ class ToolTestDescriptionDict(TypedDict): name: str inputs: Any outputs: Any - output_collections: List[Dict[str, Any]] + output_collections: List[TestSourceTestOutputColllection] stdout: Optional[AssertionList] stderr: Optional[AssertionList] expect_exit_code: Optional[int] @@ -1693,13 +1699,14 @@ class ToolTestDescription: stderr: Optional[AssertionList] command_line: Optional[AssertionList] command_version: Optional[AssertionList] - required_files: List[Any] - required_data_tables: List[Any] - required_loc_files: List[str] + required_files: RequiredFilesT + required_data_tables: RequiredDataTablesT + required_loc_files: RequiredLocFileT expect_exit_code: Optional[int] expect_failure: bool expect_test_failure: bool exception: Optional[str] + output_collections: List[TestCollectionOutputDef] def __init__(self, processed_test_dict: ToolTestDict): assert ( @@ -1708,14 +1715,31 @@ def __init__(self, processed_test_dict: ToolTestDict): test_index = processed_test_dict["test_index"] name = cast(str, processed_test_dict.get("name", f"Test-{test_index + 1}")) error_in_test_definition = processed_test_dict["error"] + num_outputs: Optional[int] = None if not error_in_test_definition: processed_test_dict = cast(ValidToolTestDict, processed_test_dict) maxseconds = int(processed_test_dict.get("maxseconds") or DEFAULT_TOOL_TEST_WAIT or 86400) output_collections = processed_test_dict.get("output_collections", []) + if "num_outputs" in processed_test_dict and processed_test_dict["num_outputs"]: + num_outputs = int(processed_test_dict["num_outputs"]) + self.required_files = processed_test_dict.get("required_files", []) + self.required_data_tables = processed_test_dict.get("required_data_tables", []) + self.required_loc_files = processed_test_dict.get("required_loc_files", []) + self.command_line = processed_test_dict.get("command_line", None) + self.command_version = processed_test_dict.get("command_version", None) + self.stdout = processed_test_dict.get("stdout", None) + self.stderr = processed_test_dict.get("stderr", None) else: processed_test_dict = cast(InvalidToolTestDict, processed_test_dict) maxseconds = DEFAULT_TOOL_TEST_WAIT output_collections = [] + self.required_files = [] + self.required_data_tables = [] + self.required_loc_files = [] + self.command_line = None + self.command_version = None + self.stdout = None + self.stderr = None self.test_index = test_index assert ( @@ -1725,9 +1749,6 @@ def __init__(self, processed_test_dict: ToolTestDict): self.tool_version = processed_test_dict.get("tool_version") self.name = name self.maxseconds = maxseconds - self.required_files = cast(List[Any], processed_test_dict.get("required_files", [])) - self.required_data_tables = cast(List[Any], processed_test_dict.get("required_data_tables", [])) - self.required_loc_files = cast(List[str], processed_test_dict.get("required_loc_files", [])) inputs = processed_test_dict.get("inputs", {}) loaded_inputs = {} @@ -1739,16 +1760,12 @@ def __init__(self, processed_test_dict: ToolTestDict): self.inputs = loaded_inputs self.outputs = processed_test_dict.get("outputs", []) - self.num_outputs = cast(Optional[int], processed_test_dict.get("num_outputs", None)) + self.num_outputs = num_outputs self.error = processed_test_dict.get("error", False) self.exception = cast(Optional[str], processed_test_dict.get("exception", None)) self.output_collections = [TestCollectionOutputDef.from_dict(d) for d in output_collections] - self.command_line = cast(Optional[AssertionList], processed_test_dict.get("command_line", None)) - self.command_version = cast(Optional[AssertionList], processed_test_dict.get("command_version", None)) - self.stdout = cast(Optional[AssertionList], processed_test_dict.get("stdout", None)) - self.stderr = cast(Optional[AssertionList], processed_test_dict.get("stderr", None)) self.expect_exit_code = cast(Optional[int], processed_test_dict.get("expect_exit_code", None)) self.expect_failure = cast(bool, processed_test_dict.get("expect_failure", False)) self.expect_test_failure = cast(bool, processed_test_dict.get("expect_test_failure", False)) diff --git a/lib/galaxy/tool_util/verify/parse.py b/lib/galaxy/tool_util/verify/parse.py index 74ae8f3a1196..b814a208a1aa 100644 --- a/lib/galaxy/tool_util/verify/parse.py +++ b/lib/galaxy/tool_util/verify/parse.py @@ -2,6 +2,7 @@ import os from typing import ( Any, + Dict, Iterable, List, Optional, @@ -31,12 +32,16 @@ string_as_bool_or_none, unicodify, ) +from ._types import ( + ExtraFileInfoDictT, + RequiredDataTablesT, + RequiredFilesT, + RequiredLocFileT, +) log = logging.getLogger(__name__) -RequiredFilesT = List[Tuple[str, dict]] -RequiredDataTablesT = List[str] -RequiredLocFileT = List[str] +AnyParamContext = Union["ParamContext", "RootParamContext"] def parse_tool_test_descriptions( @@ -127,7 +132,7 @@ def _process_raw_inputs( required_files: RequiredFilesT, required_data_tables: RequiredDataTablesT, required_loc_files: RequiredLocFileT, - parent_context: Optional[Union["ParamContext", "RootParamContext"]] = None, + parent_context: Optional[AnyParamContext] = None, ): """ Recursively expand flat list of inputs into "tree" form of flat list @@ -192,7 +197,7 @@ def _process_raw_inputs( elif input_type == "repeat": repeat_index = 0 while True: - context = ParamContext(name=name, index=repeat_index, parent_context=parent_context) + context = ParamContext(name=name, parent_context=parent_context, index=repeat_index) updated = False page_source = input_source.parse_nested_inputs_source() for r_value in page_source.parse_input_sources(): @@ -268,12 +273,12 @@ def input_sources(tool_source: ToolSource) -> List[InputSource]: class ParamContext: - def __init__(self, name, index=None, parent_context=None): + def __init__(self, name: str, parent_context: AnyParamContext, index: Optional[int] = None): self.parent_context = parent_context self.name = name self.index = None if index is None else int(index) - def for_state(self): + def for_state(self) -> str: name = self.name if self.index is None else "%s_%d" % (self.name, self.index) parent_for_state = self.parent_context.for_state() if parent_for_state: @@ -281,7 +286,7 @@ def for_state(self): else: return name - def __str__(self): + def __str__(self) -> str: return f"Context[for_state={self.for_state()}]" def param_names(self): @@ -295,14 +300,14 @@ def param_names(self): else: yield self.name - def extract_value(self, raw_inputs): + def extract_value(self, raw_inputs: ToolSourceTestInputs): for param_name in self.param_names(): value = self.__raw_param_found(param_name, raw_inputs) if value: return value return None - def __raw_param_found(self, param_name, raw_inputs): + def __raw_param_found(self, param_name: str, raw_inputs: ToolSourceTestInputs): index = None for i, raw_input_dict in enumerate(raw_inputs): if raw_input_dict["name"] == param_name: @@ -442,20 +447,26 @@ def matches_declared_value(case_value): return None -def _add_uploaded_dataset(name: str, value: Any, extra, input_parameter: InputSource, required_files: RequiredFilesT): +def _add_uploaded_dataset( + name: str, + value: Optional[str], + extra: ExtraFileInfoDictT, + input_parameter: InputSource, + required_files: RequiredFilesT, +) -> Optional[str]: if value is None: assert input_parameter.parse_optional(), f"{name} is not optional. You must provide a valid filename." return value return require_file(name, value, extra, required_files) -def require_file(name, value, extra, required_files): +def require_file(name: str, value: str, extra: ExtraFileInfoDictT, required_files: RequiredFilesT) -> str: if (value, extra) not in required_files: required_files.append((value, extra)) # these files will be uploaded - name_change = [att for att in extra.get("edit_attributes", []) if att.get("type") == "name"] - if name_change: - name_change = name_change[-1].get("value") # only the last name change really matters - value = name_change # change value for select to renamed uploaded file for e.g. composite dataset + name_changes = [att for att in extra.get("edit_attributes", []) if att.get("type") == "name"] + if name_changes: + name_change = name_changes[-1].get("value") # only the last name change really matters + value = str(name_change) # change value for select to renamed uploaded file for e.g. composite dataset else: for end in [".zip", ".gz"]: if value.endswith(end):