diff --git a/lib/galaxy/tool_util/parser/interface.py b/lib/galaxy/tool_util/parser/interface.py index 89b17698353e..a7b3896c289f 100644 --- a/lib/galaxy/tool_util/parser/interface.py +++ b/lib/galaxy/tool_util/parser/interface.py @@ -19,6 +19,7 @@ import packaging.version from pydantic import BaseModel from typing_extensions import ( + Literal, NotRequired, TypedDict, ) @@ -549,6 +550,24 @@ def parse_input_sources(self) -> List[InputSource]: """Return a list of InputSource objects.""" +TestCollectionDefElementObject = Union["TestCollectionDefDict", "ToolSourceTestInput"] +TestCollectionAttributeDict = Dict[str, Any] +CollectionType = str + + +class TestCollectionDefElementDict(TypedDict): + element_identifier: str + element_definition: TestCollectionDefElementObject + + +class TestCollectionDefDict(TypedDict): + model_class: Literal["TestCollectionDef"] + attributes: TestCollectionAttributeDict + collection_type: CollectionType + elements: List[TestCollectionDefElementDict] + name: str + + class TestCollectionDef: __test__ = False # Prevent pytest from discovering this class (issue #12071) @@ -558,30 +577,7 @@ def __init__(self, attrib, name, collection_type, elements): self.elements = elements self.name = name - @staticmethod - def from_xml(elem, parse_param_elem): - elements = [] - attrib = dict(elem.attrib) - collection_type = attrib["type"] - name = attrib.get("name", "Unnamed Collection") - for element in elem.findall("element"): - element_attrib = dict(element.attrib) - element_identifier = element_attrib["name"] - nested_collection_elem = element.find("collection") - if nested_collection_elem is not None: - element_definition = TestCollectionDef.from_xml(nested_collection_elem, parse_param_elem) - else: - element_definition = parse_param_elem(element) - elements.append({"element_identifier": element_identifier, "element_definition": element_definition}) - - return TestCollectionDef( - attrib=attrib, - collection_type=collection_type, - elements=elements, - name=name, - ) - - def to_dict(self): + def to_dict(self) -> TestCollectionDefDict: def element_to_dict(element_dict): element_identifier, element_def = element_dict["element_identifier"], element_dict["element_definition"] if isinstance(element_def, TestCollectionDef): @@ -600,7 +596,7 @@ def element_to_dict(element_dict): } @staticmethod - def from_dict(as_dict): + def from_dict(as_dict: TestCollectionDefDict): assert as_dict["model_class"] == "TestCollectionDef" def element_from_dict(element_dict): diff --git a/lib/galaxy/tool_util/parser/xml.py b/lib/galaxy/tool_util/parser/xml.py index ef9b8e5d2b48..9ab4a30f65b6 100644 --- a/lib/galaxy/tool_util/parser/xml.py +++ b/lib/galaxy/tool_util/parser/xml.py @@ -45,7 +45,9 @@ PageSource, PagesSource, RequiredFiles, - TestCollectionDef, + TestCollectionDefDict, + TestCollectionDefElementDict, + TestCollectionDefElementObject, TestCollectionOutputDef, ToolSource, ToolSourceTest, @@ -757,7 +759,7 @@ def __parse_output_elems(test_elem) -> ToolSourceTestOutputs: def __parse_output_elem(output_elem): - attrib = dict(output_elem.attrib) + attrib = _element_to_dict(output_elem) name = attrib.pop("name", None) if name is None: raise Exception("Test output does not have a 'name'") @@ -779,7 +781,7 @@ def __parse_output_collection_elems(test_elem, profile=None): def __parse_output_collection_elem(output_collection_elem, profile=None): - attrib = dict(output_collection_elem.attrib) + attrib = _element_to_dict(output_collection_elem) name = attrib.pop("name", None) if name is None: raise Exception("Test output collection does not have a 'name'") @@ -790,7 +792,7 @@ def __parse_output_collection_elem(output_collection_elem, profile=None): def __parse_element_tests(parent_element, profile=None): element_tests = {} for idx, element in enumerate(parent_element.findall("element")): - element_attrib = dict(element.attrib) + element_attrib: dict = _element_to_dict(element) identifier = element_attrib.pop("name", None) if identifier is None: raise Exception("Test primary dataset does not have a 'identifier'") @@ -861,7 +863,7 @@ def __parse_test_attributes( primary_datasets: Dict[str, Any] = {} if parse_discovered_datasets: for primary_elem in output_elem.findall("discovered_dataset") or []: - primary_attrib = dict(primary_elem.attrib) + primary_attrib = _element_to_dict(primary_elem) designation = primary_attrib.pop("designation", None) if designation is None: raise Exception("Test primary dataset does not have a 'designation'") @@ -911,7 +913,7 @@ def __parse_assert_list_from_elem(assert_elem) -> AssertionList: def convert_elem(elem): """Converts and XML element to a dictionary format, used by assertion checking code.""" tag = elem.tag - attributes = dict(elem.attrib) + attributes = _element_to_dict(elem) converted_children = [] for child_elem in elem: converted_children.append(convert_elem(child_elem)) @@ -928,7 +930,7 @@ def convert_elem(elem): def __parse_extra_files_elem(extra): # File or directory, when directory, compare basename # by basename - attrib = dict(extra.attrib) + attrib = _element_to_dict(extra) extra_type = attrib.pop("type", "file") extra_name = attrib.pop("name", None) assert ( @@ -999,6 +1001,31 @@ def __parse_inputs_elems(test_elem, i) -> ToolSourceTestInputs: return raw_inputs +def _test_collection_def_dict(elem: Element) -> TestCollectionDefDict: + elements: List[TestCollectionDefElementDict] = [] + attrib: Dict[str, Any] = _element_to_dict(elem) + collection_type = attrib["type"] + name = attrib.get("name", "Unnamed Collection") + for element in elem.findall("element"): + element_attrib: Dict[str, Any] = _element_to_dict(element) + element_identifier = element_attrib["name"] + nested_collection_elem = element.find("collection") + element_definition: TestCollectionDefElementObject + if nested_collection_elem is not None: + element_definition = _test_collection_def_dict(nested_collection_elem) + else: + element_definition = __parse_param_elem(element) + elements.append({"element_identifier": element_identifier, "element_definition": element_definition}) + + return TestCollectionDefDict( + model_class="TestCollectionDef", + attributes=attrib, + collection_type=collection_type, + elements=elements, + name=name, + ) + + def __parse_param_elem(param_elem, i=0) -> ToolSourceTestInput: attrib: ToolSourceTestInputAttributes = dict(param_elem.attrib) if "values" in attrib: @@ -1037,7 +1064,7 @@ def __parse_param_elem(param_elem, i=0) -> ToolSourceTestInput: elif child.tag == "edit_attributes": attrib["edit_attributes"].append(child) elif child.tag == "collection": - attrib["collection"] = TestCollectionDef.from_xml(child, __parse_param_elem) + attrib["collection"] = _test_collection_def_dict(child) if composite_data_name: # Composite datasets need implicit renaming; # inserted at front of list so explicit declarations @@ -1546,6 +1573,12 @@ def from_filters(self) -> Optional[DrillDownDynamicFilters]: return self._filters +def _element_to_dict(elem: Element) -> Dict[str, Any]: + # every call to this function needs to be replaced with something more type safe and with + # an actual typed dictionary - but centralizing this hack for now. + return dict(elem.attrib) # type: ignore [arg-type] + + def _recurse_drill_down_elems(options: List[DrillDownOptionsDict], option_elems: List[Element]): for option_elem in option_elems: selected = string_as_bool(option_elem.get("selected", False)) diff --git a/lib/galaxy/tool_util/verify/interactor.py b/lib/galaxy/tool_util/verify/interactor.py index e8971a0ec4f5..86c6f913d4de 100644 --- a/lib/galaxy/tool_util/verify/interactor.py +++ b/lib/galaxy/tool_util/verify/interactor.py @@ -38,6 +38,7 @@ from galaxy.tool_util.parser.interface import ( AssertionList, TestCollectionDef, + TestCollectionDefDict, TestCollectionOutputDef, TestSourceTestOutputColllection, ToolSourceTestOutputs, @@ -1749,7 +1750,8 @@ def expanded_inputs_from_json(expanded_inputs_json: ExpandedToolInputsJsonified) loaded_inputs: ExpandedToolInputs = {} for key, value in expanded_inputs_json.items(): if isinstance(value, dict) and value.get("model_class"): - loaded_inputs[key] = TestCollectionDef.from_dict(value) + collection_def_dict = cast(TestCollectionDefDict, value) + loaded_inputs[key] = TestCollectionDef.from_dict(collection_def_dict) else: loaded_inputs[key] = value return loaded_inputs diff --git a/lib/galaxy/tool_util/verify/parse.py b/lib/galaxy/tool_util/verify/parse.py index 7df3a11d8b1f..eaec6fce31d1 100644 --- a/lib/galaxy/tool_util/verify/parse.py +++ b/lib/galaxy/tool_util/verify/parse.py @@ -10,6 +10,7 @@ from galaxy.tool_util.parser.interface import ( InputSource, + TestCollectionDef, ToolSource, ToolSourceTest, ToolSourceTestInputs, @@ -246,7 +247,8 @@ def _process_raw_inputs( processed_value = param_value elif param_type == "data_collection": assert "collection" in param_extra - collection_def = param_extra["collection"] + collection_dict = param_extra["collection"] + collection_def = TestCollectionDef.from_dict(collection_dict) for input_dict in collection_def.collect_inputs(): name = input_dict["name"] value = input_dict["value"]