From dce53769bd7f323f3a5d38e926da1d5da46ad425 Mon Sep 17 00:00:00 2001 From: John Chilton Date: Wed, 1 Nov 2023 11:52:44 -0400 Subject: [PATCH] Implement default locations for data and collection parameters. Works for both files and collections. Workflow defaults override tool defaults. TODO: - Unit test case to ensure this only works for non-default, non-multi data parameters. - Implement XSD once syntax is finalized. --- lib/galaxy/managers/workflows.py | 7 +- lib/galaxy/model/__init__.py | 18 ++- lib/galaxy/tool_util/parser/interface.py | 3 + lib/galaxy/tool_util/parser/xml.py | 52 +++++++++ lib/galaxy/tool_util/parser/yaml.py | 7 ++ lib/galaxy/tools/parameters/basic.py | 107 +++++++++++++++++- lib/galaxy/workflow/modules.py | 55 +++++---- lib/galaxy/workflow/run.py | 13 ++- lib/galaxy/workflow/run_request.py | 9 +- lib/galaxy_test/api/test_workflows.py | 53 +++++++++ lib/galaxy_test/base/workflow_fixtures.py | 30 +++++ .../tools/collection_nested_default.xml | 50 ++++++++ .../tools/collection_paired_default.xml | 40 +++++++ .../tools/for_workflows/cat_default.xml | 21 ++++ test/functional/tools/sample_tool_conf.xml | 3 + test/unit/tool_util/test_parsing.py | 73 ++++++++++++ test/unit/workflows/test_modules.py | 5 +- test/unit/workflows/test_workflow_progress.py | 7 +- 18 files changed, 513 insertions(+), 40 deletions(-) create mode 100644 test/functional/tools/collection_nested_default.xml create mode 100644 test/functional/tools/collection_paired_default.xml create mode 100644 test/functional/tools/for_workflows/cat_default.xml diff --git a/lib/galaxy/managers/workflows.py b/lib/galaxy/managers/workflows.py index bbb83e316a2e..6de7c805b8be 100644 --- a/lib/galaxy/managers/workflows.py +++ b/lib/galaxy/managers/workflows.py @@ -969,7 +969,7 @@ def _workflow_to_dict_run(self, trans, stored, workflow, history=None): for pja in step.post_job_actions ] else: - inputs = step.module.get_runtime_inputs(connections=step.output_connections) + inputs = step.module.get_runtime_inputs(step, connections=step.output_connections) step_model = {"inputs": [input.to_dict(trans) for input in inputs.values()]} step_model["when"] = step.when_expression step_model["replacement_parameters"] = step.module.get_informal_replacement_parameters(step) @@ -1770,6 +1770,11 @@ def __module_from_dict( if "in" in step_dict: for input_name, input_dict in step_dict["in"].items(): + # This is just a bug in gxformat? I think the input + # defaults should be called input to match the input modules's + # input parameter name. + if input_name == "default": + input_name = "input" step_input = step.get_or_add_input(input_name) NO_DEFAULT_DEFINED = object() default = input_dict.get("default", NO_DEFAULT_DEFINED) diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py index 2b8521db0e01..eef4c8b0d125 100644 --- a/lib/galaxy/model/__init__.py +++ b/lib/galaxy/model/__init__.py @@ -7625,10 +7625,20 @@ def input_type(self): @property def input_default_value(self): - tool_state = self.tool_inputs - default_value = tool_state.get("default") - if default_value: - default_value = json.loads(default_value)["value"] + self.get_input_default_value(None) + + def get_input_default_value(self, default_default): + # parameter_input and the data parameters handle this slightly differently + # unfortunately. + if self.type == "parameter_input": + tool_state = self.tool_inputs + default_value = tool_state.get("default", default_default) + else: + default_value = default_default + for step_input in self.inputs: + if step_input.name == "input" and step_input.default_value_set: + default_value = step_input.default_value + break return default_value @property diff --git a/lib/galaxy/tool_util/parser/interface.py b/lib/galaxy/tool_util/parser/interface.py index aba4142edb46..f4c65c13dd0c 100644 --- a/lib/galaxy/tool_util/parser/interface.py +++ b/lib/galaxy/tool_util/parser/interface.py @@ -425,6 +425,9 @@ def parse_test_input_source(self) -> "InputSource": def parse_when_input_sources(self): raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE) + def parse_default(self) -> Optional[Dict[str, Any]]: + return None + class PageSource(metaclass=ABCMeta): def parse_display(self): diff --git a/lib/galaxy/tool_util/parser/xml.py b/lib/galaxy/tool_util/parser/xml.py index 60c6b935a251..3746d4b8163e 100644 --- a/lib/galaxy/tool_util/parser/xml.py +++ b/lib/galaxy/tool_util/parser/xml.py @@ -5,7 +5,9 @@ import re import uuid from typing import ( + Any, cast, + Dict, Iterable, List, Optional, @@ -1274,6 +1276,56 @@ def parse_when_input_sources(self): sources.append((value, case_page_source)) return sources + def parse_default(self) -> Optional[Dict[str, Any]]: + def file_default_from_elem(elem): + # TODO: hashes, created_from_basename, etc... + return {"class": "File", "location": elem.get("location")} + + def read_elements(collection_elem): + element_dicts = [] + elements = collection_elem.findall("element") + for element in elements: + identifier = element.get("name") + subcollection_elem = element.find("collection") + if subcollection_elem: + collection_type = subcollection_elem.get("collection_type") + element_dicts.append( + { + "class": "Collection", + "identifier": identifier, + "collection_type": collection_type, + "elements": read_elements(subcollection_elem), + } + ) + else: + element_dict = file_default_from_elem(element) + element_dict["identifier"] = identifier + element_dicts.append(element_dict) + return element_dicts + + elem = self.input_elem + element_type = self.input_elem.get("type") + if element_type == "data": + default_elem = elem.find("default") + if default_elem is not None: + return file_default_from_elem(default_elem) + else: + return None + else: + default_elem = elem.find("default") + if default_elem is not None: + default_elem = elem.find("default") + collection_type = default_elem.get("collection_type") + name = default_elem.get("name", elem.get("name")) + return { + "class": "Collection", + "name": name, + "collection_type": collection_type, + "elements": read_elements(default_elem), + } + else: + return None + class ParallelismInfo: """ diff --git a/lib/galaxy/tool_util/parser/yaml.py b/lib/galaxy/tool_util/parser/yaml.py index 64c8381b0e37..7ebc574c9ad0 100644 --- a/lib/galaxy/tool_util/parser/yaml.py +++ b/lib/galaxy/tool_util/parser/yaml.py @@ -1,7 +1,9 @@ import json from typing import ( + Any, Dict, List, + Optional, ) import packaging.version @@ -358,6 +360,11 @@ def parse_static_options(self): static_options.append((label, value, selected)) return static_options + def parse_default(self) -> Optional[Dict[str, Any]]: + input_dict = self.input_dict + default_def = input_dict.get("default", None) + return default_def + def _ensure_has(dict, defaults): for key, value in defaults.items(): diff --git a/lib/galaxy/tools/parameters/basic.py b/lib/galaxy/tools/parameters/basic.py index b178568c723e..2da33539c9b9 100644 --- a/lib/galaxy/tools/parameters/basic.py +++ b/lib/galaxy/tools/parameters/basic.py @@ -26,12 +26,16 @@ from galaxy.model import ( cached_id, Dataset, + DatasetCollection, DatasetCollectionElement, + DatasetHash, DatasetInstance, + DatasetSource, HistoryDatasetAssociation, HistoryDatasetCollectionAssociation, LibraryDatasetDatasetAssociation, ) +from galaxy.model.dataset_collections import builder from galaxy.schema.fetch_data import FilesPayload from galaxy.tool_util.parser import get_input_source as ensure_input_source from galaxy.util import ( @@ -43,6 +47,7 @@ ) from galaxy.util.dictifiable import Dictifiable from galaxy.util.expressions import ExpressionContext +from galaxy.util.hash_util import HASH_NAMES from galaxy.util.rules_dsl import RuleSet from . import ( dynamic_options, @@ -2094,6 +2099,11 @@ def __init__(self, tool, input_source, trans=None): self._parse_options(input_source) # Load conversions required for the dataset input self.conversions = [] + self.default_object = input_source.parse_default() + if self.optional and self.default_object is not None: + raise ParameterValueError( + "Cannot specify a Galaxy tool data parameter to be both optional and have a default value.", self.name + ) for name, conv_extension in input_source.parse_conversion_tuples(): assert None not in [ name, @@ -2114,9 +2124,11 @@ def from_json(self, value, trans, other_values=None): other_values = other_values or {} if trans.workflow_building_mode is workflow_building_modes.ENABLED or is_runtime_value(value): return None - if not value and not self.optional: + if not value and not self.optional and not self.default_object: raise ParameterValueError("specify a dataset of the required format / build for parameter", self.name) if value in [None, "None", ""]: + if self.default_object: + return raw_to_galaxy(trans, self.default_object) return None if isinstance(value, dict) and "values" in value: value = self.to_python(value, trans.app) @@ -2411,6 +2423,11 @@ def __init__(self, tool, input_source, trans=None): self.multiple = False # Accessed on DataToolParameter a lot, may want in future self.is_dynamic = True self._parse_options(input_source) # TODO: Review and test. + self.default_object = input_source.parse_default() + if self.optional and self.default_object is not None: + raise ParameterValueError( + "Cannot specify a Galaxy tool data parameter to be both optional and have a default value.", self.name + ) @property def collection_types(self): @@ -2447,9 +2464,11 @@ def from_json(self, value, trans, other_values=None): rval: Optional[Union[DatasetCollectionElement, HistoryDatasetCollectionAssociation]] = None if trans.workflow_building_mode is workflow_building_modes.ENABLED: return None - if not value and not self.optional: + if not value and not self.optional and not self.default_object: raise ParameterValueError("specify a dataset collection of the correct type", self.name) if value in [None, "None"]: + if self.default_object: + return raw_to_galaxy(trans, self.default_object) return None if isinstance(value, dict) and "values" in value: value = self.to_python(value, trans.app) @@ -2664,6 +2683,90 @@ def to_text(self, value): return "" +# Code from CWL branch to massage in order to be shared across tools and workflows, +# and for CWL artifacts as well as Galaxy ones. +def raw_to_galaxy(trans, as_dict_value): + app = trans.app + history = trans.history + + object_class = as_dict_value["class"] + if object_class == "File": + relative_to = "/" # TODO + from galaxy.tool_util.cwl.util import abs_path + + path = abs_path(as_dict_value.get("location"), relative_to) + + name = os.path.basename(path) + extension = as_dict_value.get("format") or "data" + dataset = Dataset() + source = DatasetSource() + source.source_uri = path + # TODO: validate this... + source.transform = as_dict_value.get("transform") + dataset.sources.append(source) + + for hash_name in HASH_NAMES: + # TODO: Convert md5 -> MD5 during tool parsing. + if hash_name in as_dict_value: + hash_object = DatasetHash() + hash_object.hash_function = hash_name + hash_object.hash_value = as_dict_value[hash_name] + dataset.hashes.append(hash_object) + + if "created_from_basename" in as_dict_value: + dataset.created_from_basename = as_dict_value["created_from_basename"] + + dataset.state = Dataset.states.DEFERRED + primary_data = HistoryDatasetAssociation( + name=name, + extension=extension, + metadata_deferred=True, + designation=None, + visible=True, + dbkey="?", + dataset=dataset, + flush=False, + sa_session=trans.sa_session, + ) + primary_data.state = Dataset.states.DEFERRED + permissions = app.security_agent.history_get_default_permissions(history) + app.security_agent.set_all_dataset_permissions(primary_data.dataset, permissions, new=True, flush=False) + trans.sa_session.add(primary_data) + history.stage_addition(primary_data) + history.add_pending_items() + trans.sa_session.flush() + return primary_data + else: + name = as_dict_value.get("name") + collection_type = as_dict_value.get("collection_type") + collection = DatasetCollection( + collection_type=collection_type, + ) + hdca = HistoryDatasetCollectionAssociation( + name=name, + collection=collection, + ) + + def write_elements_to_collection(has_elements, collection_builder): + element_dicts = has_elements.get("elements") + for element_dict in element_dicts: + element_class = element_dict["class"] + identifier = element_dict["identifier"] + if element_class == "File": + hda = raw_to_galaxy(trans, element_dict) + collection_builder.add_dataset(identifier, hda) + else: + subcollection_builder = collection_builder.get_level(identifier) + write_elements_to_collection(element_dict, subcollection_builder) + + collection_builder = builder.BoundCollectionBuilder(collection) + write_elements_to_collection(as_dict_value, collection_builder) + collection_builder.populate() + trans.sa_session.add(hdca) + trans.sa_session.flush() + return hdca + + parameter_types = dict( text=TextToolParameter, integer=IntegerToolParameter, diff --git a/lib/galaxy/workflow/modules.py b/lib/galaxy/workflow/modules.py index 805a8ba3f2f6..ce599ba6ec4c 100644 --- a/lib/galaxy/workflow/modules.py +++ b/lib/galaxy/workflow/modules.py @@ -75,6 +75,7 @@ IntegerToolParameter, is_runtime_value, parameter_types, + raw_to_galaxy, runtime_to_json, SelectToolParameter, TextToolParameter, @@ -395,7 +396,7 @@ def get_config_form(self, step=None): def get_runtime_state(self) -> DefaultToolState: raise TypeError("Abstract method") - def get_runtime_inputs(self, **kwds): + def get_runtime_inputs(self, step, connections: Optional[Iterable[WorkflowStepConnection]] = None): """Used internally by modules and when displaying inputs in workflow editor and run workflow templates. """ @@ -432,7 +433,7 @@ def update_value(input, context, prefixed_name, **kwargs): return NO_REPLACEMENT visit_input_values( - self.get_runtime_inputs(connections=step.output_connections), + self.get_runtime_inputs(step, connections=step.output_connections), state.inputs, update_value, no_replacement_value=NO_REPLACEMENT, @@ -449,19 +450,19 @@ def update_value(input, context, prefixed_name, **kwargs): return NO_REPLACEMENT visit_input_values( - self.get_runtime_inputs(), state.inputs, update_value, no_replacement_value=NO_REPLACEMENT + self.get_runtime_inputs(step), state.inputs, update_value, no_replacement_value=NO_REPLACEMENT ) return state, step_errors - def encode_runtime_state(self, runtime_state): + def encode_runtime_state(self, step, runtime_state): """Takes the computed runtime state and serializes it during run request creation.""" - return runtime_state.encode(Bunch(inputs=self.get_runtime_inputs()), self.trans.app) + return runtime_state.encode(Bunch(inputs=self.get_runtime_inputs(step)), self.trans.app) - def decode_runtime_state(self, runtime_state): + def decode_runtime_state(self, step, runtime_state): """Takes the serialized runtime state and decodes it when running the workflow.""" state = DefaultToolState() - state.decode(runtime_state, Bunch(inputs=self.get_runtime_inputs()), self.trans.app) + state.decode(runtime_state, Bunch(inputs=self.get_runtime_inputs(step)), self.trans.app) return state def execute(self, trans, progress, invocation_step, use_cached_job=False): @@ -533,7 +534,7 @@ def _find_collections_to_match(self, progress, step, all_inputs): for input_dict in all_inputs: name = input_dict["name"] - data = progress.replacement_for_input(step, input_dict) + data = progress.replacement_for_input(self.trans, step, input_dict) can_map_over = hasattr(data, "collection") # and data.collection.allow_implicit_mapping if not can_map_over: @@ -825,7 +826,7 @@ def get_runtime_state(self): state.inputs = dict() return state - def get_runtime_inputs(self, connections=None): + def get_runtime_inputs(self, step, connections: Optional[Iterable[WorkflowStepConnection]] = None): inputs = {} for step in self.subworkflow.steps: if step.type == "tool": @@ -932,7 +933,13 @@ def get_all_inputs(self, data_only=False, connectable_only=False): def execute(self, trans, progress, invocation_step, use_cached_job=False): invocation = invocation_step.workflow_invocation step = invocation_step.workflow_step - step_outputs = dict(output=step.state.inputs["input"]) + input_value = step.state.inputs["input"] + if input_value is None: + default_value = step.get_input_default_value(NO_REPLACEMENT) + if default_value is not NO_REPLACEMENT: + input_value = raw_to_galaxy(trans, default_value) + + step_outputs = dict(output=input_value) # Web controller may set copy_inputs_to_history, API controller always sets # inputs. @@ -1025,7 +1032,7 @@ def get_filter_set(self, connections=None): filter_set = ["data"] return ", ".join(filter_set) - def get_runtime_inputs(self, connections=None): + def get_runtime_inputs(self, step, connections: Optional[Iterable[WorkflowStepConnection]] = None): parameter_def = self._parse_state_into_dict() optional = parameter_def["optional"] tag = parameter_def["tag"] @@ -1037,6 +1044,10 @@ def get_runtime_inputs(self, connections=None): data_src = dict( name="input", label=self.label, multiple=False, type="data", format=formats, tag=tag, optional=optional ) + default_unset = object() + default = step.get_input_default_value(default_unset) + if default is not default_unset: + data_src["default"] = default input_param = DataToolParameter(None, data_src, self.trans) return dict(input=input_param) @@ -1096,7 +1107,7 @@ def get_inputs(self): inputs["tag"] = input_tag return inputs - def get_runtime_inputs(self, **kwds): + def get_runtime_inputs(self, step, connections: Optional[Iterable[WorkflowStepConnection]] = None): parameter_def = self._parse_state_into_dict() collection_type = parameter_def["collection_type"] optional = parameter_def["optional"] @@ -1366,7 +1377,7 @@ def get_inputs(self): parameter_type_cond.cases = cases return {"parameter_definition": parameter_type_cond} - def restrict_options(self, connections: Iterable[WorkflowStepConnection], default_value): + def restrict_options(self, step, connections: Iterable[WorkflowStepConnection], default_value): try: static_options = [] # Retrieve possible runtime options for 'select' type inputs @@ -1389,7 +1400,7 @@ def callback(input, prefixed_name, context, **kwargs): for step in module.subworkflow.input_steps: if step.input_type == "parameter" and step.label == subworkflow_input_name: static_options.append( - step.module.get_runtime_inputs(connections=step.output_connections)[ + step.module.get_runtime_inputs(step, connections=step.output_connections)[ "input" ].static_options ) @@ -1421,7 +1432,7 @@ def callback(input, prefixed_name, context, **kwargs): except Exception: log.debug("Failed to generate options for text parameter, falling back to free text.", exc_info=True) - def get_runtime_inputs(self, connections: Optional[Iterable[WorkflowStepConnection]] = None, **kwds): + def get_runtime_inputs(self, step, connections: Optional[Iterable[WorkflowStepConnection]] = None): parameter_def = self._parse_state_into_dict() parameter_type = parameter_def["parameter_type"] optional = parameter_def["optional"] @@ -1440,7 +1451,7 @@ def get_runtime_inputs(self, connections: Optional[Iterable[WorkflowStepConnecti attemptRestrictOnConnections = is_text and parameter_def.get("restrictOnConnections") and connections if attemptRestrictOnConnections: connections = cast(Iterable[WorkflowStepConnection], connections) - restricted_options = self.restrict_options(connections=connections, default_value=default_value) + restricted_options = self.restrict_options(step, connections=connections, default_value=default_value) if restricted_options is not None: restricted_inputs = True parameter_kwds["options"] = restricted_options @@ -1518,7 +1529,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False): step = invocation_step.workflow_step input_value = step.state.inputs["input"] if input_value is None: - default_value = safe_loads(step.tool_inputs.get("default", "{}")) + default_value = step.get_input_default_value(NO_REPLACEMENT) # TODO: look at parameter type and infer if value should be a dictionary # instead. Guessing only field parameter types in CWL branch would have # default as dictionary like this. @@ -1985,7 +1996,7 @@ def callback(input, prefixed_name, context, **kwargs): output_step for output_step in steps if connection.output_step_id == output_step.id ) if output_step.type.startswith("data"): - output_inputs = output_step.module.get_runtime_inputs(connections=connections) + output_inputs = output_step.module.get_runtime_inputs(output_step, connections=connections) output_value = output_inputs["input"].get_initial_value(self.trans, context) if input_type == "data" and isinstance( output_value, self.trans.app.model.HistoryDatasetCollectionAssociation @@ -2087,7 +2098,7 @@ def get_runtime_state(self): state.inputs = self.state.inputs return state - def get_runtime_inputs(self, **kwds): + def get_runtime_inputs(self, step, connections: Optional[Iterable[WorkflowStepConnection]] = None): return self.get_inputs() def compute_runtime_state(self, trans, step=None, step_updates=None): @@ -2108,12 +2119,12 @@ def compute_runtime_state(self, trans, step=None, step_updates=None): f"Tool {self.tool_id} missing. Cannot compute runtime state.", tool_id=self.tool_id ) - def decode_runtime_state(self, runtime_state): + def decode_runtime_state(self, step, runtime_state): """Take runtime state from persisted invocation and convert it into a DefaultToolState object for use during workflow invocation. """ if self.tool: - state = super().decode_runtime_state(runtime_state) + state = super().decode_runtime_state(step, runtime_state) if RUNTIME_STEP_META_STATE_KEY in runtime_state: self.__restore_step_meta_runtime_state(json.loads(runtime_state[RUNTIME_STEP_META_STATE_KEY])) return state @@ -2169,7 +2180,7 @@ def callback(input, prefixed_name, **kwargs): if iteration_elements and prefixed_name in iteration_elements: # noqa: B023 replacement = iteration_elements[prefixed_name] # noqa: B023 else: - replacement = progress.replacement_for_input(step, input_dict) + replacement = progress.replacement_for_input(trans, step, input_dict) if replacement is not NO_REPLACEMENT: if not isinstance(input, BaseDataToolParameter): diff --git a/lib/galaxy/workflow/run.py b/lib/galaxy/workflow/run.py index d69fe79b6aa2..3cb73768e997 100644 --- a/lib/galaxy/workflow/run.py +++ b/lib/galaxy/workflow/run.py @@ -31,6 +31,7 @@ InvocationWarningWorkflowOutputNotFound, WarningReason, ) +from galaxy.tools.parameters.basic import raw_to_galaxy from galaxy.util import ExecutionTimer from galaxy.workflow import modules from galaxy.workflow.run_request import ( @@ -399,7 +400,7 @@ def remaining_steps( raise MessageException(public_message) runtime_state = step_states[step_id].value assert step.module - step.state = step.module.decode_runtime_state(runtime_state) + step.state = step.module.decode_runtime_state(step, runtime_state) invocation_step = step_invocations_by_id.get(step_id, None) if invocation_step and invocation_step.state == "scheduled": @@ -408,7 +409,7 @@ def remaining_steps( remaining_steps.append((step, invocation_step)) return remaining_steps - def replacement_for_input(self, step: "WorkflowStep", input_dict: Dict[str, Any]) -> Any: + def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[str, Any]) -> Any: replacement: Union[ modules.NoReplacement, model.DatasetCollectionInstance, @@ -416,6 +417,7 @@ def replacement_for_input(self, step: "WorkflowStep", input_dict: Dict[str, Any] ] = modules.NO_REPLACEMENT prefixed_name = input_dict["name"] multiple = input_dict["multiple"] + is_data = input_dict["input_type"] in ["dataset", "dataset_collection"] if prefixed_name in step.input_connections_by_name: connection = step.input_connections_by_name[prefixed_name] if input_dict["input_type"] == "dataset" and multiple: @@ -431,9 +433,12 @@ def replacement_for_input(self, step: "WorkflowStep", input_dict: Dict[str, Any] else: replacement = temp else: - is_data = input_dict["input_type"] in ["dataset", "dataset_collection"] replacement = self.replacement_for_connection(connection[0], is_data=is_data) - + else: + for step_input in step.inputs: + if step_input.name == prefixed_name and step_input.default_value_set: + if is_data: + replacement = raw_to_galaxy(trans, step_input.default_value) return replacement def replacement_for_connection(self, connection: "WorkflowStepConnection", is_data: bool = True) -> Any: diff --git a/lib/galaxy/workflow/run_request.py b/lib/galaxy/workflow/run_request.py index d3b5a6039d48..4f38370ddc62 100644 --- a/lib/galaxy/workflow/run_request.py +++ b/lib/galaxy/workflow/run_request.py @@ -119,13 +119,16 @@ def _normalize_inputs( for possible_input_key in possible_input_keys: if possible_input_key in inputs: inputs_key = possible_input_key - default_value = step.tool_inputs.get("default") + + default_not_set = object() + has_default = step.get_input_default_value(default_not_set) is not default_not_set optional = step.input_optional # Need to be careful here to make sure 'default' has correct type - not sure how to do that # but asserting 'optional' is definitely a bool and not a String->Bool or something is a good # start to ensure tool state is being preserved and loaded in a type safe way. assert isinstance(optional, bool) - if not inputs_key and default_value is None and not optional: + assert isinstance(has_default, bool) + if not inputs_key and not has_default and not optional: message = f"Workflow cannot be run because an expected input step '{step.id}' ({step.label}) is not optional and no input." raise exceptions.MessageException(message) if inputs_key: @@ -495,7 +498,7 @@ def add_parameter(name: str, value: str, type: WorkflowRequestInputParameter.typ for step in workflow.steps: steps_by_id[step.id] = step assert step.module - serializable_runtime_state = step.module.encode_runtime_state(step.state) + serializable_runtime_state = step.module.encode_runtime_state(step, step.state) step_state = WorkflowRequestStepState() step_state.workflow_step = step diff --git a/lib/galaxy_test/api/test_workflows.py b/lib/galaxy_test/api/test_workflows.py index 4483a096ef6a..276b989460ca 100644 --- a/lib/galaxy_test/api/test_workflows.py +++ b/lib/galaxy_test/api/test_workflows.py @@ -57,11 +57,13 @@ WORKFLOW_WITH_BAD_COLUMN_PARAMETER_GOOD_TEST_DATA, WORKFLOW_WITH_CUSTOM_REPORT_1, WORKFLOW_WITH_CUSTOM_REPORT_1_TEST_DATA, + WORKFLOW_WITH_DEFAULT_FILE_DATASET_INPUT, WORKFLOW_WITH_DYNAMIC_OUTPUT_COLLECTION, WORKFLOW_WITH_MAPPED_OUTPUT_COLLECTION, WORKFLOW_WITH_OUTPUT_COLLECTION, WORKFLOW_WITH_OUTPUT_COLLECTION_MAPPING, WORKFLOW_WITH_RULES_1, + WORKFLOW_WITH_STEP_DEFAULT_FILE_DATASET_INPUT, ) from ._framework import ApiTestCase from .sharable import SharingApiTests @@ -4694,6 +4696,57 @@ def test_run_with_validated_parameter_connection_default_values(self): content = self.dataset_populator.get_history_dataset_content(history_id) assert len(content.splitlines()) == 3, content + def test_run_with_default_file_dataset_input(self): + with self.dataset_populator.test_history() as history_id: + run_response = self._run_workflow( + WORKFLOW_WITH_DEFAULT_FILE_DATASET_INPUT, + history_id=history_id, + wait=True, + assert_ok=True, + ) + invocation_details = self.workflow_populator.get_invocation(run_response.invocation_id, step_details=True) + assert invocation_details["steps"][0]["outputs"]["output"]["src"] == "hda" + dataset_details = self.dataset_populator.get_history_dataset_details( + history_id, dataset_id=invocation_details["steps"][1]["outputs"]["out_file1"]["id"] + ) + assert dataset_details["file_ext"] == "txt" + assert "chr1" in dataset_details["peek"] + + def test_run_with_default_file_dataset_input_and_explicit_input(self): + with self.dataset_populator.test_history() as history_id: + run_response = self._run_workflow( + WORKFLOW_WITH_DEFAULT_FILE_DATASET_INPUT, + test_data=""" +default_file_input: + value: 1.fasta + type: File +""", + history_id=history_id, + wait=True, + assert_ok=True, + ) + invocation_details = self.workflow_populator.get_invocation(run_response.invocation_id, step_details=True) + assert invocation_details["steps"][0]["outputs"]["output"]["src"] == "hda" + dataset_details = self.dataset_populator.get_history_dataset_details( + history_id, dataset_id=invocation_details["steps"][1]["outputs"]["out_file1"]["id"] + ) + assert dataset_details["file_ext"] == "txt" + assert ( + "gtttgccatcttttgctgctctagggaatccagcagctgtcaccatgtaaacaagcccaggctagaccaGTTACCCTCATCATCTTAGCTGATAGCCAGCCAGCCACCACAGGCA" + in dataset_details["peek"] + ) + + def test_run_with_default_file_in_step_inline(self): + with self.dataset_populator.test_history() as history_id: + self._run_workflow( + WORKFLOW_WITH_STEP_DEFAULT_FILE_DATASET_INPUT, + history_id=history_id, + wait=True, + assert_ok=True, + ) + content = self.dataset_populator.get_history_dataset_content(history_id) + assert "chr1" in content + def test_run_with_validated_parameter_connection_invalid(self): with self.dataset_populator.test_history() as history_id: self._run_jobs( diff --git a/lib/galaxy_test/base/workflow_fixtures.py b/lib/galaxy_test/base/workflow_fixtures.py index a01d39ce52c3..912adecb32f9 100644 --- a/lib/galaxy_test/base/workflow_fixtures.py +++ b/lib/galaxy_test/base/workflow_fixtures.py @@ -1147,3 +1147,33 @@ outer_output_2: outputSource: subworkflow/inner_output_2 """ + +WORKFLOW_WITH_DEFAULT_FILE_DATASET_INPUT = """ +class: GalaxyWorkflow +inputs: + default_file_input: + default: + class: File + basename: a file + format: txt + location: https://raw.githubusercontent.com/galaxyproject/galaxy/dev/test-data/1.bed +steps: + cat1: + tool_id: cat1 + in: + input1: default_file_input +""" + +WORKFLOW_WITH_STEP_DEFAULT_FILE_DATASET_INPUT = """ +class: GalaxyWorkflow +steps: + cat1: + tool_id: cat1 + in: + input1: + default: + class: File + basename: a file + format: txt + location: https://raw.githubusercontent.com/galaxyproject/galaxy/dev/test-data/1.bed +""" diff --git a/test/functional/tools/collection_nested_default.xml b/test/functional/tools/collection_nested_default.xml new file mode 100644 index 000000000000..dddfa9a8eb80 --- /dev/null +++ b/test/functional/tools/collection_nested_default.xml @@ -0,0 +1,50 @@ + + + echo #for $f in $f1# ${f.is_collection} #end for# >> $out1; + cat #for $f in $f1# #if $f.is_collection# #for $inner in $f# ${inner} #end for# #else# $f # #end if# #end for# >> $out2 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/functional/tools/collection_paired_default.xml b/test/functional/tools/collection_paired_default.xml new file mode 100644 index 000000000000..cfe47d985f03 --- /dev/null +++ b/test/functional/tools/collection_paired_default.xml @@ -0,0 +1,40 @@ + + + cat $f1.forward $f1['reverse'] >> $out1; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/functional/tools/for_workflows/cat_default.xml b/test/functional/tools/for_workflows/cat_default.xml new file mode 100644 index 000000000000..19de2d7cc3ba --- /dev/null +++ b/test/functional/tools/for_workflows/cat_default.xml @@ -0,0 +1,21 @@ + + + '$out_file1' + ]]> + + + + + + + + + + + + + + + + diff --git a/test/functional/tools/sample_tool_conf.xml b/test/functional/tools/sample_tool_conf.xml index 6069630972b8..f6be322f2b3e 100644 --- a/test/functional/tools/sample_tool_conf.xml +++ b/test/functional/tools/sample_tool_conf.xml @@ -176,9 +176,11 @@ + + @@ -236,6 +238,7 @@ parameter, and multiple datasets from a collection. --> + diff --git a/test/unit/tool_util/test_parsing.py b/test/unit/tool_util/test_parsing.py index 9ef3a3a23893..25575348636a 100644 --- a/test/unit/tool_util/test_parsing.py +++ b/test/unit/tool_util/test_parsing.py @@ -693,6 +693,79 @@ def test_test(self): assert output0["attributes"]["object"] is None +class TestDefaultDataTestToolLoader(BaseLoaderTestCase): + source_file_name = os.path.join(galaxy_directory(), "test/functional/tools/for_workflows/cat_default.xml") + source_contents = None + + def test_input_parsing(self): + input_pages = self._tool_source.parse_input_pages() + assert input_pages.inputs_defined + page_sources = input_pages.page_sources + assert len(page_sources) == 1 + page_source = page_sources[0] + input_sources = page_source.parse_input_sources() + assert len(input_sources) == 1 + data_input = input_sources[0] + default_dict = data_input.parse_default() + assert default_dict + assert default_dict["location"] == "https://raw.githubusercontent.com/galaxyproject/galaxy/dev/test-data/1.bed" + + +class TestDefaultCollectionDataTestToolLoader(BaseLoaderTestCase): + source_file_name = os.path.join(galaxy_directory(), "test/functional/tools/collection_paired_default.xml") + source_contents = None + + def test_input_parsing(self): + input_pages = self._tool_source.parse_input_pages() + assert input_pages.inputs_defined + page_sources = input_pages.page_sources + assert len(page_sources) == 1 + page_source = page_sources[0] + input_sources = page_source.parse_input_sources() + assert len(input_sources) == 1 + data_input = input_sources[0] + default_dict = data_input.parse_default() + assert default_dict + assert default_dict["collection_type"] == "paired" + elements = default_dict["elements"] + assert len(elements) == 2 + element0 = elements[0] + assert element0["identifier"] == "forward" + assert element0["location"] == "https://raw.githubusercontent.com/galaxyproject/galaxy/dev/test-data/1.bed" + element1 = elements[1] + assert element1["identifier"] == "reverse" + assert element1["location"] == "https://raw.githubusercontent.com/galaxyproject/galaxy/dev/test-data/1.fasta" + + +class TestDefaultNestedCollectionDataTestToolLoader(BaseLoaderTestCase): + source_file_name = os.path.join(galaxy_directory(), "test/functional/tools/collection_nested_default.xml") + source_contents = None + + def test_input_parsing(self): + input_pages = self._tool_source.parse_input_pages() + assert input_pages.inputs_defined + page_sources = input_pages.page_sources + assert len(page_sources) == 1 + page_source = page_sources[0] + input_sources = page_source.parse_input_sources() + assert len(input_sources) == 1 + data_input = input_sources[0] + default_dict = data_input.parse_default() + assert default_dict + assert default_dict["collection_type"] == "list:paired" + elements = default_dict["elements"] + assert len(elements) == 1 + element0 = elements[0] + assert element0["identifier"] == "i1" + + elements0 = element0["elements"] + assert len(elements0) == 2 + elements00 = elements0[0] + assert elements00["identifier"] == "forward" + elements01 = elements0[1] + assert elements01["identifier"] == "reverse" + + class TestExpressionOutputDataToolLoader(BaseLoaderTestCase): source_file_name = os.path.join(galaxy_directory(), "test/functional/tools/expression_pick_larger_file.xml") source_contents = None diff --git a/test/unit/workflows/test_modules.py b/test/unit/workflows/test_modules.py index 51dc049b7562..27955d3d4f96 100644 --- a/test/unit/workflows/test_modules.py +++ b/test/unit/workflows/test_modules.py @@ -427,7 +427,10 @@ def __new_subworkflow_module(workflow=TEST_WORKFLOW_YAML): def __assert_has_runtime_input(module, label=None, collection_type=None): - inputs = module.get_runtime_inputs() + test_step = getattr(module, "test_step", None) + if test_step is None: + test_step = mock.MagicMock() + inputs = module.get_runtime_inputs(test_step) assert len(inputs) == 1 assert "input" in inputs input_param = inputs["input"] diff --git a/test/unit/workflows/test_workflow_progress.py b/test/unit/workflows/test_workflow_progress.py index 1d44868ff7f6..5abe1013f106 100644 --- a/test/unit/workflows/test_workflow_progress.py +++ b/test/unit/workflows/test_workflow_progress.py @@ -132,7 +132,7 @@ def test_replacement_for_tool_input(self): "input_type": "dataset", "multiple": False, } - replacement = progress.replacement_for_input(self._step(2), step_dict) + replacement = progress.replacement_for_input(None, self._step(2), step_dict) assert replacement is hda def test_connect_tool_output(self): @@ -169,7 +169,7 @@ def test_remaining_steps_with_progress(self): "input_type": "dataset", "multiple": False, } - replacement = progress.replacement_for_input(self._step(4), step_dict) + replacement = progress.replacement_for_input(None, self._step(4), step_dict) assert replacement is hda3 # TODO: Replace multiple true HDA with HDCA @@ -216,6 +216,7 @@ def test_subworkflow_progress(self): "multiple": False, } assert hda is subworkflow_progress.replacement_for_input( + None, subworkflow_cat_step, step_dict, ) @@ -242,7 +243,7 @@ class MockModule: def __init__(self, progress): self.progress = progress - def decode_runtime_state(self, runtime_state): + def decode_runtime_state(self, step, runtime_state): return True def recover_mapping(self, invocation_step, progress):