Fix type annotation resulting from annotating HDCA.collection
mvdbeek committed Jan 10, 2025
1 parent a0919d1 commit 7566674
Showing 10 changed files with 63 additions and 44 deletions.
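Context for the diff below: Galaxy's models use SQLAlchemy 2.0 typed relationships, and once HDCA.collection and the DatasetCollectionElement relationships carry Mapped[...] annotations, mypy sees hda, ldda, and child_collection as Optional, so every call site has to narrow away None. A minimal sketch of the mechanism, using illustrative stand-in models rather than Galaxy's own:

from typing import Optional

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Collection(Base):
    __tablename__ = "collection"
    id: Mapped[int] = mapped_column(primary_key=True)


class Element(Base):
    __tablename__ = "element"
    id: Mapped[int] = mapped_column(primary_key=True)
    collection_id: Mapped[int] = mapped_column(ForeignKey("collection.id"))
    child_collection_id: Mapped[Optional[int]] = mapped_column(ForeignKey("collection.id"))

    # Every element belongs to a collection: non-optional annotation.
    collection: Mapped[Collection] = relationship(foreign_keys=[collection_id])
    # Only nested elements have a child collection: Optional annotation,
    # so mypy requires narrowing before attribute access.
    child_collection: Mapped[Optional[Collection]] = relationship(foreign_keys=[child_collection_id])


def child_id(element: Element) -> Optional[int]:
    if element.child_collection:  # without this check: 'Item "None" has no attribute "id"'
        return element.child_collection.id
    return None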
15 changes: 8 additions & 7 deletions lib/galaxy/model/__init__.py
@@ -6845,9 +6845,10 @@ def dataset_elements_and_identifiers(self, identifiers=None):
     def first_dataset_element(self) -> Optional["DatasetCollectionElement"]:
         for element in self.elements:
             if element.is_collection:
-                first_element = element.child_collection.first_dataset_element
-                if first_element:
-                    return first_element
+                if element.child_collection:
+                    first_element = element.child_collection.first_dataset_element
+                    if first_element:
+                        return first_element
             else:
                 return element
         return None
@@ -7437,18 +7438,18 @@ class DatasetCollectionElement(Base, Dictifiable, Serializable):
     element_index: Mapped[Optional[int]]
     element_identifier: Mapped[Optional[str]] = mapped_column(Unicode(255))

-    hda = relationship(
+    hda: Mapped[Optional["HistoryDatasetAssociation"]] = relationship(
         "HistoryDatasetAssociation",
         primaryjoin=(lambda: DatasetCollectionElement.hda_id == HistoryDatasetAssociation.id),
     )
-    ldda = relationship(
+    ldda: Mapped[Optional["LibraryDatasetDatasetAssociation"]] = relationship(
         "LibraryDatasetDatasetAssociation",
         primaryjoin=(lambda: DatasetCollectionElement.ldda_id == LibraryDatasetDatasetAssociation.id),
     )
-    child_collection = relationship(
+    child_collection: Mapped[Optional["DatasetCollection"]] = relationship(
         "DatasetCollection", primaryjoin=(lambda: DatasetCollectionElement.child_collection_id == DatasetCollection.id)
     )
-    collection = relationship(
+    collection: Mapped[DatasetCollection] = relationship(
         "DatasetCollection",
         primaryjoin=(lambda: DatasetCollection.id == DatasetCollectionElement.dataset_collection_id),
         back_populates="elements",
4 changes: 3 additions & 1 deletion lib/galaxy/tool_util/parser/interface.py
@@ -337,7 +337,9 @@ def parse_provided_metadata_file(self):
         return "galaxy.json"

     @abstractmethod
-    def parse_outputs(self, tool: "Tool") -> Tuple[Dict[str, "ToolOutput"], Dict[str, "ToolOutputCollection"]]:
+    def parse_outputs(
+        self, tool: Optional["Tool"]
+    ) -> Tuple[Dict[str, "ToolOutput"], Dict[str, "ToolOutputCollection"]]:
         """Return a pair of output and output collections ordered
         dictionaries for use by Tool.
         """
3 changes: 2 additions & 1 deletion lib/galaxy/tool_util/parser/output_models.py
@@ -8,6 +8,7 @@
 from typing import (
     List,
     Optional,
+    Sequence,
     Union,
 )

@@ -105,7 +106,7 @@ class FilePatternDatasetCollectionDescription(DatasetCollectionDescription):
 ToolOutput = Annotated[ToolOutputT, Field(discriminator="type")]


-def from_tool_source(tool_source: ToolSource) -> List[ToolOutput]:
+def from_tool_source(tool_source: ToolSource) -> Sequence[ToolOutput]:
     tool_outputs, tool_output_collections = tool_source.parse_outputs(None)
     outputs = []
     for tool_output in tool_outputs.values():
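The List-to-Sequence switch above (mirrored in test_parsing.py further down) is the usual mypy variance fix: List is invariant, so a list whose inferred element type is narrower than the declared union does not type-check against List[ToolOutput], while the read-only Sequence is covariant and accepts it. A small illustrative sketch, not Galaxy code:

from typing import List, Sequence, Union


class DatasetOutput:
    pass


class CollectionOutput:
    pass


Output = Union[DatasetOutput, CollectionOutput]


def as_list(outputs: List[DatasetOutput]) -> List[Output]:
    return outputs  # mypy error: List is invariant in its element type


def as_sequence(outputs: List[DatasetOutput]) -> Sequence[Output]:
    return outputs  # OK: Sequence is covariant, callers get a read-only view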
2 changes: 1 addition & 1 deletion lib/galaxy/tool_util/parser/output_objects.py
@@ -281,7 +281,7 @@ def __init__(
         self.collection = True
         self.default_format = default_format
         self.structure = structure
-        self.outputs: Dict[str, str] = {}
+        self.outputs: Dict[str, ToolOutput] = {}

         self.inherit_format = inherit_format
         self.inherit_metadata = inherit_metadata
35 changes: 20 additions & 15 deletions lib/galaxy/tools/evaluation.py
@@ -65,6 +65,7 @@
     safe_makedirs,
     unicodify,
 )
+from galaxy.util.path import StrPath
 from galaxy.util.template import (
     fill_template,
     InputNotFoundSyntaxError,
@@ -103,7 +104,7 @@ def __init__(self, *args: object, tool_id: Optional[str], tool_version: str, is_
         self.is_latest = is_latest


-def global_tool_logs(func, config_file: str, action_str: str, tool: "Tool"):
+def global_tool_logs(func, config_file: Optional[StrPath], action_str: str, tool: "Tool"):
     try:
         return func()
     except Exception as e:
@@ -385,6 +386,9 @@ def do_walk(inputs, input_values):
         do_walk(inputs, input_values)

     def __populate_wrappers(self, param_dict, input_datasets, job_working_directory):
+
+        element_identifier_mapper = ElementIdentifierMapper(input_datasets)
+
         def wrap_input(input_values, input):
             value = input_values[input.name]
             if isinstance(input, DataToolParameter) and input.multiple:
@@ -401,26 +405,26 @@ def wrap_input(input_values, input):

             elif isinstance(input, DataToolParameter):
                 dataset = input_values[input.name]
-                wrapper_kwds = dict(
+                element_identifier = element_identifier_mapper.identifier(dataset, param_dict)
+                input_values[input.name] = DatasetFilenameWrapper(
+                    dataset=dataset,
                     datatypes_registry=self.app.datatypes_registry,
                     tool=self.tool,
                     name=input.name,
                     compute_environment=self.compute_environment,
+                    identifier=element_identifier,
+                    formats=input.formats,
                 )
-                element_identifier = element_identifier_mapper.identifier(dataset, param_dict)
-                if element_identifier:
-                    wrapper_kwds["identifier"] = element_identifier
-                wrapper_kwds["formats"] = input.formats
-                input_values[input.name] = DatasetFilenameWrapper(dataset, **wrapper_kwds)
             elif isinstance(input, DataCollectionToolParameter):
                 dataset_collection = value
-                wrapper_kwds = dict(
+                wrapper = DatasetCollectionWrapper(
+                    job_working_directory=job_working_directory,
+                    has_collection=dataset_collection,
                     datatypes_registry=self.app.datatypes_registry,
                     compute_environment=self.compute_environment,
                     tool=self.tool,
                     name=input.name,
                 )
-                wrapper = DatasetCollectionWrapper(job_working_directory, dataset_collection, **wrapper_kwds)
                 input_values[input.name] = wrapper
             elif isinstance(input, SelectToolParameter):
                 if input.multiple:
@@ -430,14 +434,13 @@ def wrap_input(input_values, input):
                     )
                 else:
                     input_values[input.name] = InputValueWrapper(
-                        input, value, param_dict, profile=self.tool and self.tool.profile
+                        input, value, param_dict, profile=self.tool and self.tool.profile or None
                     )

         # HACK: only wrap if check_values is not false, this deals with external
         # tools where the inputs don't even get passed through. These
         # tools (e.g. UCSC) should really be handled in a special way.
         if self.tool.check_values:
-            element_identifier_mapper = ElementIdentifierMapper(input_datasets)
             self.__walk_inputs(self.tool.inputs, param_dict, wrap_input)

     def __populate_input_dataset_wrappers(self, param_dict, input_datasets):
@@ -464,13 +467,13 @@ def __populate_input_dataset_wrappers(self, param_dict, input_datasets):
                 param_dict[name] = wrapper
                 continue
             if not isinstance(param_dict_value, ToolParameterValueWrapper):
-                wrapper_kwds = dict(
+                param_dict[name] = DatasetFilenameWrapper(
+                    dataset=data,
                     datatypes_registry=self.app.datatypes_registry,
                     tool=self.tool,
                     name=name,
                     compute_environment=self.compute_environment,
                 )
-                param_dict[name] = DatasetFilenameWrapper(data, **wrapper_kwds)

     def __populate_output_collection_wrappers(self, param_dict, output_collections, job_working_directory):
         tool = self.tool
@@ -481,14 +484,15 @@ def __populate_output_collection_wrappers(self, param_dict, output_collections,
             # message = message_template % ( name, tool.output_collections )
             # raise AssertionError( message )

-            wrapper_kwds = dict(
+            wrapper = DatasetCollectionWrapper(
+                job_working_directory=job_working_directory,
+                has_collection=out_collection,
                 datatypes_registry=self.app.datatypes_registry,
                 compute_environment=self.compute_environment,
                 io_type="output",
                 tool=tool,
                 name=name,
             )
-            wrapper = DatasetCollectionWrapper(job_working_directory, out_collection, **wrapper_kwds)
             param_dict[name] = wrapper
             # TODO: Handle nested collections...
             for element_identifier, output_def in tool.output_collections[name].outputs.items():
@@ -683,6 +687,7 @@ def _build_command_line(self):
         if interpreter:
             # TODO: path munging for cluster/dataset server relocatability
             executable = command_line.split()[0]
+            assert self.tool.tool_dir
             tool_dir = os.path.abspath(self.tool.tool_dir)
             abs_executable = os.path.join(tool_dir, executable)
             command_line = command_line.replace(executable, f"{interpreter} {shlex.quote(abs_executable)}", 1)
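The recurring edit in this file replaces the wrapper_kwds = dict(...) / **wrapper_kwds pattern with direct keyword arguments. Splatting a dict gives mypy only a single inferred value type for all arguments, so the call cannot be matched key-by-key against the signature; explicit keywords restore per-argument checking. A minimal sketch with hypothetical names standing in for the wrapper constructors:

from typing import List, Optional


def make_wrapper(dataset: str, identifier: Optional[str] = None, formats: Optional[List[str]] = None) -> str:
    # Stand-in for DatasetFilenameWrapper's constructor.
    return f"{dataset}:{identifier}:{formats}"


def splatted(dataset: str) -> str:
    # mypy infers Dict[str, object] for wrapper_kwds; the splat cannot be
    # matched key-by-key against the signature, so checking degrades.
    wrapper_kwds = dict(identifier="sample1", formats=["tabular"])
    return make_wrapper(dataset, **wrapper_kwds)


def explicit(dataset: str) -> str:
    # Each keyword is checked against the declared parameter type.
    return make_wrapper(dataset, identifier="sample1", formats=["tabular"])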
35 changes: 21 additions & 14 deletions lib/galaxy/tools/parameters/basic.py
@@ -1998,6 +1998,7 @@ def do_validate(v):
                     dataset_count += 1
                     do_validate(v.hda)
                 else:
+                    assert v.child_collection
                     for dataset_instance in v.child_collection.dataset_instances:
                         dataset_count += 1
                         do_validate(dataset_instance)
@@ -2176,33 +2177,39 @@ def from_json(self, value, trans, other_values=None):
             dataset_matcher_factory = get_dataset_matcher_factory(trans)
             dataset_matcher = dataset_matcher_factory.dataset_matcher(self, other_values)
             for v in rval:
+                value_to_check: Union[
+                    DatasetInstance, DatasetCollection, DatasetCollectionElement, HistoryDatasetCollectionAssociation
+                ] = v
                 if isinstance(v, DatasetCollectionElement):
                     if hda := v.hda:
-                        v = hda
+                        value_to_check = hda
                     elif ldda := v.ldda:
-                        v = ldda
+                        value_to_check = ldda
                     elif collection := v.child_collection:
-                        v = collection
-                    elif not v.collection and v.collection.populated_optimized:
+                        value_to_check = collection
+                    elif v.collection and not v.collection.populated_optimized:
                         raise ParameterValueError("the selected collection has not been populated.", self.name)
                     else:
                         raise ParameterValueError("Collection element in unexpected state", self.name)
-                if isinstance(v, DatasetInstance):
-                    if v.deleted:
+                if isinstance(value_to_check, DatasetInstance):
+                    if value_to_check.deleted:
                         raise ParameterValueError("the previously selected dataset has been deleted.", self.name)
-                    elif v.dataset and v.dataset.state in [Dataset.states.ERROR, Dataset.states.DISCARDED]:
+                    elif value_to_check.dataset and value_to_check.dataset.state in [
+                        Dataset.states.ERROR,
+                        Dataset.states.DISCARDED,
+                    ]:
                         raise ParameterValueError(
                             "the previously selected dataset has entered an unusable state", self.name
                         )
-                    match = dataset_matcher.hda_match(v)
+                    match = dataset_matcher.hda_match(value_to_check)
                     if match and match.implicit_conversion:
-                        v.implicit_conversion = True  # type:ignore[union-attr]
-                elif isinstance(v, HistoryDatasetCollectionAssociation):
-                    if v.deleted:
+                        value_to_check.implicit_conversion = True  # type:ignore[attr-defined]
+                elif isinstance(value_to_check, HistoryDatasetCollectionAssociation):
+                    if value_to_check.deleted:
                         raise ParameterValueError("the previously selected dataset collection has been deleted.", self.name)
-                    v = v.collection
-                if isinstance(v, DatasetCollection):
-                    if v.elements_deleted:
+                    value_to_check = value_to_check.collection
+                if isinstance(value_to_check, DatasetCollection):
+                    if value_to_check.elements_deleted:
                         raise ParameterValueError(
                             "the previously selected dataset collection has elements that are deleted.", self.name
                         )
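The from_json rewrite above swaps rebinding v for a dedicated value_to_check variable carrying an explicit Union annotation. Rebinding a loop variable to values of several unrelated types defeats mypy's narrowing, whereas a separately declared union lets each isinstance branch narrow cleanly. A compact sketch of the pattern, with stand-in classes rather than Galaxy's:

from typing import List, Union


class DatasetLike:
    deleted = False


class CollectionLike:
    elements_deleted = False


class ElementLike:
    def __init__(self, child: CollectionLike) -> None:
        self.child_collection = child


def validate(items: List[Union[DatasetLike, ElementLike]]) -> None:
    for v in items:
        value_to_check: Union[DatasetLike, CollectionLike, ElementLike] = v
        if isinstance(value_to_check, ElementLike):
            # Rebinding value_to_check (not v) keeps v's type stable for mypy.
            value_to_check = value_to_check.child_collection
        if isinstance(value_to_check, DatasetLike):
            if value_to_check.deleted:
                raise ValueError("dataset deleted")
        elif isinstance(value_to_check, CollectionLike):
            if value_to_check.elements_deleted:
                raise ValueError("collection has deleted elements")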
6 changes: 4 additions & 2 deletions lib/galaxy/tools/wrappers.py
@@ -517,7 +517,7 @@ def __iter__(self) -> Iterator[Any]:
         pass

     def _dataset_wrapper(
-        self, dataset: Union[DatasetInstance, DatasetCollectionElement], **kwargs: Any
+        self, dataset: Optional[Union[DatasetInstance, DatasetCollectionElement]], **kwargs: Any
     ) -> DatasetFilenameWrapper:
         return DatasetFilenameWrapper(dataset, **kwargs)

@@ -647,6 +647,7 @@ def __init__(
             collection = has_collection.collection
             self.name = has_collection.name
         elif isinstance(has_collection, DatasetCollectionElement):
+            assert has_collection.child_collection
             collection = has_collection.child_collection
             self.name = has_collection.element_identifier
         else:
@@ -661,8 +662,9 @@ def __init__(
         for dataset_collection_element in elements:
             element_object = dataset_collection_element.element_object
             element_identifier = dataset_collection_element.element_identifier
+            assert element_identifier is not None

-            if dataset_collection_element.is_collection:
+            if isinstance(element_object, DatasetCollection):
                 element_wrapper: DatasetCollectionElementWrapper = DatasetCollectionWrapper(
                     job_working_directory, dataset_collection_element, **kwargs
                 )
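Several hunks in this commit (wrappers.py above, basic.py, evaluation.py, and the tests below) add bare assert statements purely for the type checker: when a code path guarantees an Optional relationship is set, an assert narrows the type without restructuring the logic. A one-function sketch:

from typing import Optional


def element_label(element_identifier: Optional[str]) -> str:
    # mypy narrows Optional[str] to str past this line; at runtime the
    # assert documents (and enforces) the invariant.
    assert element_identifier is not None
    return element_identifier.upper()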
2 changes: 1 addition & 1 deletion test/unit/app/tools/test_evaluation.py
@@ -45,7 +45,7 @@ def setUp(self):
         self.job.history = History()
         self.job.history.id = 42
         self.job.parameters = [JobParameter(name="thresh", value="4")]
-        self.evaluator = ToolEvaluator(self.app, self.tool, self.job, self.test_directory)
+        self.evaluator = ToolEvaluator(self.app, self.tool, self.job, self.test_directory)  # type: ignore[arg-type]

     def tearDown(self):
         self.tear_down_app()
1 change: 1 addition & 0 deletions test/unit/data/test_dataset_materialization.py
@@ -374,6 +374,7 @@ def _deferred_element_count(dataset_collection: DatasetCollection) -> int:
     count = 0
     for element in dataset_collection.elements:
         if element.is_collection:
+            assert element.child_collection
             count += _deferred_element_count(element.child_collection)
         else:
             dataset_instance = element.dataset_instance
4 changes: 2 additions & 2 deletions test/unit/tool_util/test_parsing.py
@@ -5,8 +5,8 @@
 from math import isinf
 from typing import (
     cast,
-    List,
     Optional,
+    Sequence,
     Type,
     TypeVar,
 )
@@ -262,7 +262,7 @@ def _tool_source(self):
         return self._get_tool_source()

     @property
-    def _output_models(self) -> List[ToolOutput]:
+    def _output_models(self) -> Sequence[ToolOutput]:
         return from_tool_source(self._tool_source)

     def _get_tool_source(self, source_file_name=None, source_contents=None, macro_contents=None):
