Skip to content

Commit

Permalink
More typing of test case parsing.
Browse files Browse the repository at this point in the history
  • Loading branch information
jmchilton committed Jul 18, 2024
1 parent 071c376 commit 33bf596
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 33 deletions.
3 changes: 2 additions & 1 deletion lib/galaxy/tool_util/parser/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ class AssertionDict(TypedDict):

ToolSourceTestInputs = Any
ToolSourceTestOutputs = Any
TestSourceTestOutputColllection = Any


class ToolSourceTest(TypedDict):
inputs: ToolSourceTestInputs
outputs: ToolSourceTestOutputs
output_collections: List[Any]
output_collections: List[TestSourceTestOutputColllection]
stdout: AssertionList
stderr: AssertionList
expect_exit_code: Optional[XmlInt]
Expand Down
14 changes: 14 additions & 0 deletions lib/galaxy/tool_util/verify/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Types used by interactor and test case processor."""

from typing import (
Any,
Dict,
List,
Tuple,
)

ExtraFileInfoDictT = Dict[str, Any]
RequiredFileTuple = Tuple[str, ExtraFileInfoDictT]
RequiredFilesT = List[RequiredFileTuple]
RequiredDataTablesT = List[str]
RequiredLocFileT = List[str]
49 changes: 33 additions & 16 deletions lib/galaxy/tool_util/verify/interactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
AssertionList,
TestCollectionDef,
TestCollectionOutputDef,
TestSourceTestOutputColllection,
)
from galaxy.util import requests
from galaxy.util.bunch import Bunch
Expand All @@ -47,6 +48,11 @@
parse_checksum_hash,
)
from . import verify
from ._types import (
RequiredDataTablesT,
RequiredFilesT,
RequiredLocFileT,
)
from .asserts import verify_assertions
from .wait import wait_on

Expand Down Expand Up @@ -92,7 +98,7 @@ def __getitem__(self, item):
class ValidToolTestDict(TypedDict):
inputs: Any
outputs: Any
output_collections: List[Dict[str, Any]]
output_collections: List[TestSourceTestOutputColllection]
stdout: NotRequired[AssertionList]
stderr: NotRequired[AssertionList]
expect_exit_code: NotRequired[Optional[Union[str, int]]]
Expand All @@ -102,9 +108,9 @@ class ValidToolTestDict(TypedDict):
num_outputs: NotRequired[Optional[Union[str, int]]]
command_line: NotRequired[AssertionList]
command_version: NotRequired[AssertionList]
required_files: NotRequired[List[Any]]
required_data_tables: NotRequired[List[Any]]
required_loc_files: NotRequired[List[str]]
required_files: NotRequired[RequiredFilesT]
required_data_tables: NotRequired[RequiredDataTablesT]
required_loc_files: NotRequired[RequiredLocFileT]
error: Literal[False]
tool_id: str
tool_version: str
Expand Down Expand Up @@ -1661,7 +1667,7 @@ class ToolTestDescriptionDict(TypedDict):
name: str
inputs: Any
outputs: Any
output_collections: List[Dict[str, Any]]
output_collections: List[TestSourceTestOutputColllection]
stdout: Optional[AssertionList]
stderr: Optional[AssertionList]
expect_exit_code: Optional[int]
Expand Down Expand Up @@ -1693,13 +1699,14 @@ class ToolTestDescription:
stderr: Optional[AssertionList]
command_line: Optional[AssertionList]
command_version: Optional[AssertionList]
required_files: List[Any]
required_data_tables: List[Any]
required_loc_files: List[str]
required_files: RequiredFilesT
required_data_tables: RequiredDataTablesT
required_loc_files: RequiredLocFileT
expect_exit_code: Optional[int]
expect_failure: bool
expect_test_failure: bool
exception: Optional[str]
output_collections: List[TestCollectionOutputDef]

def __init__(self, processed_test_dict: ToolTestDict):
assert (
Expand All @@ -1708,14 +1715,31 @@ def __init__(self, processed_test_dict: ToolTestDict):
test_index = processed_test_dict["test_index"]
name = cast(str, processed_test_dict.get("name", f"Test-{test_index + 1}"))
error_in_test_definition = processed_test_dict["error"]
num_outputs: Optional[int] = None
if not error_in_test_definition:
processed_test_dict = cast(ValidToolTestDict, processed_test_dict)
maxseconds = int(processed_test_dict.get("maxseconds") or DEFAULT_TOOL_TEST_WAIT or 86400)
output_collections = processed_test_dict.get("output_collections", [])
if "num_outputs" in processed_test_dict and processed_test_dict["num_outputs"]:
num_outputs = int(processed_test_dict["num_outputs"])
self.required_files = processed_test_dict.get("required_files", [])
self.required_data_tables = processed_test_dict.get("required_data_tables", [])
self.required_loc_files = processed_test_dict.get("required_loc_files", [])
self.command_line = processed_test_dict.get("command_line", None)
self.command_version = processed_test_dict.get("command_version", None)
self.stdout = processed_test_dict.get("stdout", None)
self.stderr = processed_test_dict.get("stderr", None)
else:
processed_test_dict = cast(InvalidToolTestDict, processed_test_dict)
maxseconds = DEFAULT_TOOL_TEST_WAIT
output_collections = []
self.required_files = []
self.required_data_tables = []
self.required_loc_files = []
self.command_line = None
self.command_version = None
self.stdout = None
self.stderr = None

self.test_index = test_index
assert (
Expand All @@ -1725,9 +1749,6 @@ def __init__(self, processed_test_dict: ToolTestDict):
self.tool_version = processed_test_dict.get("tool_version")
self.name = name
self.maxseconds = maxseconds
self.required_files = cast(List[Any], processed_test_dict.get("required_files", []))
self.required_data_tables = cast(List[Any], processed_test_dict.get("required_data_tables", []))
self.required_loc_files = cast(List[str], processed_test_dict.get("required_loc_files", []))

inputs = processed_test_dict.get("inputs", {})
loaded_inputs = {}
Expand All @@ -1739,16 +1760,12 @@ def __init__(self, processed_test_dict: ToolTestDict):

self.inputs = loaded_inputs
self.outputs = processed_test_dict.get("outputs", [])
self.num_outputs = cast(Optional[int], processed_test_dict.get("num_outputs", None))
self.num_outputs = num_outputs

self.error = processed_test_dict.get("error", False)
self.exception = cast(Optional[str], processed_test_dict.get("exception", None))

self.output_collections = [TestCollectionOutputDef.from_dict(d) for d in output_collections]
self.command_line = cast(Optional[AssertionList], processed_test_dict.get("command_line", None))
self.command_version = cast(Optional[AssertionList], processed_test_dict.get("command_version", None))
self.stdout = cast(Optional[AssertionList], processed_test_dict.get("stdout", None))
self.stderr = cast(Optional[AssertionList], processed_test_dict.get("stderr", None))
self.expect_exit_code = cast(Optional[int], processed_test_dict.get("expect_exit_code", None))
self.expect_failure = cast(bool, processed_test_dict.get("expect_failure", False))
self.expect_test_failure = cast(bool, processed_test_dict.get("expect_test_failure", False))
Expand Down
43 changes: 27 additions & 16 deletions lib/galaxy/tool_util/verify/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Expand Down Expand Up @@ -31,12 +32,16 @@
string_as_bool_or_none,
unicodify,
)
from ._types import (
ExtraFileInfoDictT,
RequiredDataTablesT,
RequiredFilesT,
RequiredLocFileT,
)

log = logging.getLogger(__name__)

RequiredFilesT = List[Tuple[str, dict]]
RequiredDataTablesT = List[str]
RequiredLocFileT = List[str]
AnyParamContext = Union["ParamContext", "RootParamContext"]


def parse_tool_test_descriptions(
Expand Down Expand Up @@ -127,7 +132,7 @@ def _process_raw_inputs(
required_files: RequiredFilesT,
required_data_tables: RequiredDataTablesT,
required_loc_files: RequiredLocFileT,
parent_context: Optional[Union["ParamContext", "RootParamContext"]] = None,
parent_context: Optional[AnyParamContext] = None,
):
"""
Recursively expand flat list of inputs into "tree" form of flat list
Expand Down Expand Up @@ -192,7 +197,7 @@ def _process_raw_inputs(
elif input_type == "repeat":
repeat_index = 0
while True:
context = ParamContext(name=name, index=repeat_index, parent_context=parent_context)
context = ParamContext(name=name, parent_context=parent_context, index=repeat_index)
updated = False
page_source = input_source.parse_nested_inputs_source()
for r_value in page_source.parse_input_sources():
Expand Down Expand Up @@ -268,20 +273,20 @@ def input_sources(tool_source: ToolSource) -> List[InputSource]:


class ParamContext:
def __init__(self, name, index=None, parent_context=None):
def __init__(self, name: str, parent_context: AnyParamContext, index: Optional[int] = None):
    # Link to the enclosing context (another ParamContext or the RootParamContext).
    self.parent_context = parent_context
    self.name = name
    # Repeat blocks pass their iteration index (see the "repeat" handling in
    # _process_raw_inputs); normalize to int, None means "not inside a repeat".
    self.index = None if index is None else int(index)

def for_state(self):
def for_state(self) -> str:
    """Return the flattened state path for this context, e.g. ``parent|name_0``."""
    # Repeat instances get an index suffix: "<name>_<index>".
    name = self.name if self.index is None else "%s_%d" % (self.name, self.index)
    parent_for_state = self.parent_context.for_state()
    if parent_for_state:
        # Nested contexts are joined with "|" separators.
        return f"{parent_for_state}|{name}"
    else:
        return name

def __str__(self):
def __str__(self) -> str:
    # Debug-friendly representation built from the flattened state path.
    return f"Context[for_state={self.for_state()}]"

def param_names(self):
Expand All @@ -295,14 +300,14 @@ def param_names(self):
else:
yield self.name

def extract_value(self, raw_inputs):
def extract_value(self, raw_inputs: ToolSourceTestInputs):
    """Return the first raw-input value matching one of this context's param names.

    Returns None when no candidate name yields a value.

    NOTE(review): the ``if value`` check uses truthiness, so falsy values
    (empty string, 0, False) are treated the same as "not found" — confirm
    this is intentional.
    """
    for param_name in self.param_names():
        value = self.__raw_param_found(param_name, raw_inputs)
        if value:
            return value
    return None

def __raw_param_found(self, param_name, raw_inputs):
def __raw_param_found(self, param_name: str, raw_inputs: ToolSourceTestInputs):
index = None
for i, raw_input_dict in enumerate(raw_inputs):
if raw_input_dict["name"] == param_name:
Expand Down Expand Up @@ -442,20 +447,26 @@ def matches_declared_value(case_value):
return None


def _add_uploaded_dataset(name: str, value: Any, extra, input_parameter: InputSource, required_files: RequiredFilesT):
def _add_uploaded_dataset(
    name: str,
    value: Optional[str],
    extra: ExtraFileInfoDictT,
    input_parameter: InputSource,
    required_files: RequiredFilesT,
) -> Optional[str]:
    """Record *value* as a file the test must upload for parameter *name*.

    A missing value is tolerated only for optional inputs; otherwise the file
    is registered via ``require_file`` and the (possibly renamed) value is
    returned.
    """
    if value is not None:
        return require_file(name, value, extra, required_files)
    # No filename supplied: only acceptable when the tool input is optional.
    assert input_parameter.parse_optional(), f"{name} is not optional. You must provide a valid filename."
    return None


def require_file(name, value, extra, required_files):
def require_file(name: str, value: str, extra: ExtraFileInfoDictT, required_files: RequiredFilesT) -> str:
if (value, extra) not in required_files:
required_files.append((value, extra)) # these files will be uploaded
name_change = [att for att in extra.get("edit_attributes", []) if att.get("type") == "name"]
if name_change:
name_change = name_change[-1].get("value") # only the last name change really matters
value = name_change # change value for select to renamed uploaded file for e.g. composite dataset
name_changes = [att for att in extra.get("edit_attributes", []) if att.get("type") == "name"]
if name_changes:
name_change = name_changes[-1].get("value") # only the last name change really matters
value = str(name_change) # change value for select to renamed uploaded file for e.g. composite dataset
else:
for end in [".zip", ".gz"]:
if value.endswith(end):
Expand Down

0 comments on commit 33bf596

Please sign in to comment.