diff --git a/lib/galaxy/tool_util/parser/interface.py b/lib/galaxy/tool_util/parser/interface.py index 89b17698353e..3c2ca3abb64e 100644 --- a/lib/galaxy/tool_util/parser/interface.py +++ b/lib/galaxy/tool_util/parser/interface.py @@ -19,6 +19,7 @@ import packaging.version from pydantic import BaseModel from typing_extensions import ( + Literal, NotRequired, TypedDict, ) @@ -549,6 +550,19 @@ def parse_input_sources(self) -> List[InputSource]: """Return a list of InputSource objects.""" +class TestCollectionDefElementDict(TypedDict): + element_identifier: str + element_definition: Union["TestCollectionDefDict", "ToolSourceTestInput"] + + +class TestCollectionDefDict(TypedDict): + model_class: Literal["TestCollectionDef"] + attributes: Dict[str, Any] + collection_type: str + elements: List[TestCollectionDefElementDict] + name: str + + class TestCollectionDef: __test__ = False # Prevent pytest from discovering this class (issue #12071) @@ -558,30 +572,7 @@ def __init__(self, attrib, name, collection_type, elements): self.elements = elements self.name = name - @staticmethod - def from_xml(elem, parse_param_elem): - elements = [] - attrib = dict(elem.attrib) - collection_type = attrib["type"] - name = attrib.get("name", "Unnamed Collection") - for element in elem.findall("element"): - element_attrib = dict(element.attrib) - element_identifier = element_attrib["name"] - nested_collection_elem = element.find("collection") - if nested_collection_elem is not None: - element_definition = TestCollectionDef.from_xml(nested_collection_elem, parse_param_elem) - else: - element_definition = parse_param_elem(element) - elements.append({"element_identifier": element_identifier, "element_definition": element_definition}) - - return TestCollectionDef( - attrib=attrib, - collection_type=collection_type, - elements=elements, - name=name, - ) - - def to_dict(self): + def to_dict(self) -> TestCollectionDefDict: def element_to_dict(element_dict): element_identifier, element_def = element_dict["element_identifier"], element_dict["element_definition"] if isinstance(element_def, TestCollectionDef): @@ -600,7 +591,7 @@ def element_to_dict(element_dict): } @staticmethod - def from_dict(as_dict): + def from_dict(as_dict: TestCollectionDefDict): assert as_dict["model_class"] == "TestCollectionDef" def element_from_dict(element_dict): diff --git a/lib/galaxy/tool_util/parser/xml.py b/lib/galaxy/tool_util/parser/xml.py index ef9b8e5d2b48..c2055c7bc612 100644 --- a/lib/galaxy/tool_util/parser/xml.py +++ b/lib/galaxy/tool_util/parser/xml.py @@ -45,8 +45,9 @@ PageSource, PagesSource, RequiredFiles, - TestCollectionDef, TestCollectionOutputDef, + TestCollectionDefDict, + TestCollectionDefElementDict, ToolSource, ToolSourceTest, ToolSourceTestInput, @@ -790,7 +791,7 @@ def __parse_output_collection_elem(output_collection_elem, profile=None): def __parse_element_tests(parent_element, profile=None): element_tests = {} for idx, element in enumerate(parent_element.findall("element")): - element_attrib = dict(element.attrib) + element_attrib: dict = dict(element.attrib) identifier = element_attrib.pop("name", None) if identifier is None: raise Exception("Test primary dataset does not have a 'identifier'") @@ -999,6 +1000,30 @@ def __parse_inputs_elems(test_elem, i) -> ToolSourceTestInputs: return raw_inputs +def _test_collection_def_dict(elem: Element) -> TestCollectionDefDict: + elements: TestCollectionDefElementDict = [] + attrib: dict = dict(elem.attrib) + collection_type = attrib["type"] + name = attrib.get("name", "Unnamed Collection") + for element in elem.findall("element"): + element_attrib = dict(element.attrib) + element_identifier = element_attrib["name"] + nested_collection_elem = element.find("collection") + if nested_collection_elem is not None: + element_definition = _test_collection_def_dict(nested_collection_elem) + else: + element_definition = __parse_param_elem(element) + elements.append({"element_identifier": element_identifier, "element_definition": element_definition}) + + return TestCollectionDefDict( + model_class="TestCollectionDef", + attributes=attrib, + collection_type=collection_type, + elements=elements, + name=name, + ) + + def __parse_param_elem(param_elem, i=0) -> ToolSourceTestInput: attrib: ToolSourceTestInputAttributes = dict(param_elem.attrib) if "values" in attrib: @@ -1037,7 +1062,7 @@ def __parse_param_elem(param_elem, i=0) -> ToolSourceTestInput: elif child.tag == "edit_attributes": attrib["edit_attributes"].append(child) elif child.tag == "collection": - attrib["collection"] = TestCollectionDef.from_xml(child, __parse_param_elem) + attrib["collection"] = _test_collection_def_dict(child) if composite_data_name: # Composite datasets need implicit renaming; # inserted at front of list so explicit declarations diff --git a/lib/galaxy/tool_util/verify/interactor.py b/lib/galaxy/tool_util/verify/interactor.py index e8971a0ec4f5..86c6f913d4de 100644 --- a/lib/galaxy/tool_util/verify/interactor.py +++ b/lib/galaxy/tool_util/verify/interactor.py @@ -38,6 +38,7 @@ from galaxy.tool_util.parser.interface import ( AssertionList, TestCollectionDef, + TestCollectionDefDict, TestCollectionOutputDef, TestSourceTestOutputColllection, ToolSourceTestOutputs, @@ -1749,7 +1750,8 @@ def expanded_inputs_from_json(expanded_inputs_json: ExpandedToolInputsJsonified) loaded_inputs: ExpandedToolInputs = {} for key, value in expanded_inputs_json.items(): if isinstance(value, dict) and value.get("model_class"): - loaded_inputs[key] = TestCollectionDef.from_dict(value) + collection_def_dict = cast(TestCollectionDefDict, value) + loaded_inputs[key] = TestCollectionDef.from_dict(collection_def_dict) else: loaded_inputs[key] = value return loaded_inputs diff --git a/lib/galaxy/tool_util/verify/parse.py b/lib/galaxy/tool_util/verify/parse.py index 7df3a11d8b1f..eaec6fce31d1 100644 --- a/lib/galaxy/tool_util/verify/parse.py +++ b/lib/galaxy/tool_util/verify/parse.py @@ -10,6 +10,7 @@ from galaxy.tool_util.parser.interface import ( InputSource, + TestCollectionDef, ToolSource, ToolSourceTest, ToolSourceTestInputs, @@ -246,7 +247,8 @@ def _process_raw_inputs( processed_value = param_value elif param_type == "data_collection": assert "collection" in param_extra - collection_def = param_extra["collection"] + collection_dict = param_extra["collection"] + collection_def = TestCollectionDef.from_dict(collection_dict) for input_dict in collection_def.collect_inputs(): name = input_dict["name"] value = input_dict["value"]