From 352b44e8848ccae1b88d4d4c6b0bdef29c5f8588 Mon Sep 17 00:00:00 2001 From: John Chilton Date: Thu, 29 Aug 2024 13:54:12 -0400 Subject: [PATCH] Typing around tool state. --- lib/galaxy/managers/jobs.py | 36 ++- lib/galaxy/tools/__init__.py | 218 ++++++++++++------- lib/galaxy/tools/_types.py | 65 ++++++ lib/galaxy/tools/actions/__init__.py | 23 +- lib/galaxy/tools/actions/data_manager.py | 4 +- lib/galaxy/tools/actions/history_imp_exp.py | 6 +- lib/galaxy/tools/actions/metadata.py | 4 +- lib/galaxy/tools/actions/model_operations.py | 4 +- lib/galaxy/tools/actions/upload.py | 4 +- lib/galaxy/tools/execute.py | 74 +++++-- lib/galaxy/tools/parameters/__init__.py | 179 ++++++++++----- lib/galaxy/tools/parameters/basic.py | 1 + lib/galaxy/tools/parameters/grouping.py | 43 ++-- lib/galaxy/tools/parameters/meta.py | 8 +- lib/galaxy/webapps/galaxy/api/jobs.py | 8 +- lib/galaxy/webapps/galaxy/api/workflows.py | 3 +- lib/galaxy/work/context.py | 2 +- lib/galaxy/workflow/modules.py | 27 +-- test/unit/app/tools/test_evaluation.py | 6 +- 19 files changed, 481 insertions(+), 234 deletions(-) create mode 100644 lib/galaxy/tools/_types.py diff --git a/lib/galaxy/managers/jobs.py b/lib/galaxy/managers/jobs.py index 546085b26a2e..b7221cf0fef9 100644 --- a/lib/galaxy/managers/jobs.py +++ b/lib/galaxy/managers/jobs.py @@ -9,6 +9,7 @@ Dict, List, Optional, + Union, ) import sqlalchemy @@ -41,7 +42,10 @@ Safety, ) from galaxy.managers.collections import DatasetCollectionManager -from galaxy.managers.context import ProvidesUserContext +from galaxy.managers.context import ( + ProvidesHistoryContext, + ProvidesUserContext, +) from galaxy.managers.datasets import DatasetManager from galaxy.managers.hdas import HDAManager from galaxy.managers.lddas import LDDAManager @@ -68,6 +72,10 @@ ) from galaxy.security.idencoding import IdEncodingHelper from galaxy.structured_app import StructuredApp +from galaxy.tools._types import ( + ToolStateDumpedToJsonInternalT, + ToolStateJobInstancePopulatedT, +) from galaxy.util import ( defaultdict, ExecutionTimer, @@ -81,6 +89,9 @@ log = logging.getLogger(__name__) +JobStateT = str +JobStatesT = Union[JobStateT, List[JobStateT]] + class JobLock(BaseModel): active: bool = Field(title="Job lock status", description="If active, jobs will not dispatch") @@ -311,7 +322,15 @@ def __init__( self.ldda_manager = ldda_manager self.decode_id = id_encoding_helper.decode_id - def by_tool_input(self, trans, tool_id, tool_version, param=None, param_dump=None, job_state="ok"): + def by_tool_input( + self, + trans: ProvidesHistoryContext, + tool_id: str, + tool_version: Optional[str], + param: ToolStateJobInstancePopulatedT, + param_dump: ToolStateDumpedToJsonInternalT, + job_state: Optional[JobStatesT] = "ok", + ): """Search for jobs producing same results using the 'inputs' part of a tool POST.""" user = trans.user input_data = defaultdict(list) @@ -353,7 +372,14 @@ def populate_input_data_input_id(path, key, value): ) def __search( - self, tool_id, tool_version, user, input_data, job_state=None, param_dump=None, wildcard_param_dump=None + self, + tool_id: str, + tool_version: Optional[str], + user: model.User, + input_data, + job_state: Optional[JobStatesT], + param_dump: ToolStateDumpedToJsonInternalT, + wildcard_param_dump=None, ): search_timer = ExecutionTimer() @@ -462,7 +488,9 @@ def replace_dataset_ids(path, key, value): log.info("No equivalent jobs found %s", search_timer) return None - def _build_job_subquery(self, tool_id, user_id, tool_version, job_state, wildcard_param_dump): + def _build_job_subquery( + self, tool_id: str, user_id: int, tool_version: Optional[str], job_state, wildcard_param_dump + ): """Build subquery that selects a job with correct job parameters.""" stmt = select(model.Job.id).where( and_( diff --git a/lib/galaxy/tools/__init__.py b/lib/galaxy/tools/__init__.py index 7193fd6ea618..6933b09d1809 100644 --- a/lib/galaxy/tools/__init__.py +++ b/lib/galaxy/tools/__init__.py @@ -120,6 +120,8 @@ check_param, params_from_strings, params_to_incoming, + params_to_json, + params_to_json_internal, params_to_strings, populate_state, visit_input_values, @@ -183,7 +185,19 @@ get_tool_shed_url_from_tool_shed_registry, ) from galaxy.version import VERSION_MAJOR -from galaxy.work.context import proxy_work_context_for_history +from galaxy.work.context import ( + proxy_work_context_for_history, + WorkRequestContext, +) +from ._types import ( + InputFormatT, + ParameterValidationErrorsT, + ToolRequestT, + ToolStateDumpedToJsonInternalT, + ToolStateDumpedToJsonT, + ToolStateJobInstancePopulatedT, + ToolStateJobInstanceT, +) from .execute import ( DatasetCollectionElementsSliceT, DEFAULT_JOB_CALLBACK, @@ -195,8 +209,6 @@ ExecutionSlice, JobCallbackT, MappingParameters, - ToolParameterRequestInstanceT, - ToolParameterRequestT, ) if TYPE_CHECKING: @@ -712,7 +724,7 @@ def encode(self, tool, app, nested=False): """ Convert the data to a string """ - value = params_to_strings(tool.inputs, self.inputs, app, nested=nested) + value = cast(Dict[str, Any], params_to_strings(tool.inputs, self.inputs, app, nested=nested)) value["__page__"] = self.page value["__rerun_remap_job_id__"] = self.rerun_remap_job_id return value @@ -1514,8 +1526,8 @@ def parse_input_elem( # Repeat group input_type = input_source.parse_input_type() if input_type == "repeat": - group_r = Repeat() - group_r.name = input_source.get("name") + repeat_name = input_source.get("name") + group_r = Repeat(repeat_name) group_r.title = input_source.get("title") group_r.help = input_source.get("help", None) page_source = input_source.parse_nested_inputs_source() @@ -1531,15 +1543,17 @@ def parse_input_elem( group_r.default = cast(int, min(max(group_r.default, group_r.min), group_r.max)) rval[group_r.name] = group_r elif input_type == "conditional": - group_c = Conditional() - group_c.name = input_source.get("name") - group_c.value_ref = input_source.get("value_ref", None) + cond_name = input_source.get("name") + group_c = Conditional(cond_name) + value_ref = input_source.get("value_ref", None) + group_c.value_ref = value_ref group_c.value_ref_in_group = input_source.get_bool("value_ref_in_group", True) value_from = input_source.get("value_from", None) if value_from: value_from = value_from.split(":") temp_value_from = locals().get(value_from[0]) - group_c.test_param = rval[group_c.value_ref] + assert value_ref + group_c.test_param = rval[value_ref] assert isinstance(group_c.test_param, ToolParameter) group_c.test_param.refresh_on_change = True for attr in value_from[1].split("."): @@ -1600,8 +1614,8 @@ def parse_input_elem( group_c.cases.append(case) rval[group_c.name] = group_c elif input_type == "section": - group_s = Section() - group_s.name = input_source.get("name") + section_name = input_source.get("name") + group_s = Section(section_name) group_s.title = input_source.get("title") group_s.help = input_source.get("help", None) group_s.expanded = input_source.get_bool("expanded", False) @@ -1610,8 +1624,8 @@ def parse_input_elem( rval[group_s.name] = group_s elif input_type == "upload_dataset": elem = input_source.elem() - group_u = UploadDataset() - group_u.name = elem.get("name") + upload_name = elem.get("name") + group_u = UploadDataset(upload_name) group_u.title = elem.get("title") group_u.file_type_name = elem.get("file_type_name", group_u.file_type_name) group_u.default_file_type = elem.get("default_file_type", group_u.default_file_type) @@ -1796,22 +1810,22 @@ def visit_inputs(self, values, callback): if self.check_values: visit_input_values(self.inputs, values, callback) - def expand_incoming(self, trans, incoming, request_context, input_format="legacy"): - rerun_remap_job_id = None - if "rerun_remap_job_id" in incoming: - try: - rerun_remap_job_id = trans.app.security.decode_id(incoming["rerun_remap_job_id"]) - except Exception as exception: - log.error(str(exception)) - raise exceptions.MessageException( - "Failure executing tool with id '%s' (attempting to rerun invalid job).", self.id - ) - + def expand_incoming( + self, request_context: WorkRequestContext, incoming: ToolRequestT, input_format: InputFormatT = "legacy" + ) -> Tuple[ + List[ToolStateJobInstancePopulatedT], + List[ToolStateJobInstancePopulatedT], + Optional[int], + Optional[MatchingCollections], + ]: + rerun_remap_job_id = _rerun_remap_job_id(request_context, incoming, self.id) set_dataset_matcher_factory(request_context, self) # Fixed set of input parameters may correspond to any number of jobs. # Expand these out to individual parameters for given jobs (tool executions). - expanded_incomings, collection_info = expand_meta_parameters(trans, self, incoming) + expanded_incomings: List[ToolStateJobInstanceT] + collection_info: Optional[MatchingCollections] + expanded_incomings, collection_info = expand_meta_parameters(request_context, self, incoming) # Remapping a single job to many jobs doesn't make sense, so disable # remap if multi-runs of tools are being used. @@ -1831,39 +1845,11 @@ def expand_incoming(self, trans, incoming, request_context, input_format="legacy "internals.galaxy.tools.validation", "Validated and populated state for tool request", ) - all_errors = [] - all_params = [] + all_errors: List[ParameterValidationErrorsT] = [] + all_params: List[ToolStateJobInstancePopulatedT] = [] + for expanded_incoming in expanded_incomings: - params = {} - errors: Dict[str, str] = {} - if self.input_translator: - self.input_translator.translate(expanded_incoming) - if not self.check_values: - # If `self.check_values` is false we don't do any checking or - # processing on input This is used to pass raw values - # through to/from external sites. - params = expanded_incoming - else: - # Update state for all inputs on the current page taking new - # values from `incoming`. - populate_state( - request_context, - self.inputs, - expanded_incoming, - params, - errors, - simple_errors=False, - input_format=input_format, - ) - # If the tool provides a `validate_input` hook, call it. - validate_input = self.get_hook("validate_input") - if validate_input: - # hooks are so terrible ... this is specifically for https://github.com/galaxyproject/tools-devteam/blob/main/tool_collections/gops/basecoverage/operation_filter.py - legacy_non_dce_params = { - k: v.hda if isinstance(v, model.DatasetCollectionElement) and v.hda else v - for k, v in params.items() - } - validate_input(request_context, errors, legacy_non_dce_params, self.inputs) + params, errors = self._populate(request_context, expanded_incoming, input_format) all_errors.append(errors) all_params.append(params) unset_dataset_matcher_factory(request_context) @@ -1871,14 +1857,74 @@ def expand_incoming(self, trans, incoming, request_context, input_format="legacy log.info(validation_timer) return all_params, all_errors, rerun_remap_job_id, collection_info + def _populate( + self, request_context, expanded_incoming: ToolStateJobInstanceT, input_format: InputFormatT + ) -> Tuple[ToolStateJobInstancePopulatedT, ParameterValidationErrorsT]: + """Validate expanded parameters for a job to replace references with model objects. + + So convert a ToolStateJobInstanceT to a ToolStateJobInstancePopulatedT. + """ + params: ToolStateJobInstancePopulatedT = {} + errors: ParameterValidationErrorsT = {} + if self.input_translator: + self.input_translator.translate(expanded_incoming) + if not self.check_values: + # If `self.check_values` is false we don't do any checking or + # processing on input This is used to pass raw values + # through to/from external sites. + params = cast(ToolStateJobInstancePopulatedT, expanded_incoming) + else: + # Update state for all inputs on the current page taking new + # values from `incoming`. + populate_state( + request_context, + self.inputs, + expanded_incoming, + params, + errors, + simple_errors=False, + input_format=input_format, + ) + # If the tool provides a `validate_input` hook, call it. + validate_input = self.get_hook("validate_input") + if validate_input: + # hooks are so terrible ... this is specifically for https://github.com/galaxyproject/tools-devteam/blob/main/tool_collections/gops/basecoverage/operation_filter.py + legacy_non_dce_params = { + k: v.hda if isinstance(v, model.DatasetCollectionElement) and v.hda else v + for k, v in params.items() + } + validate_input(request_context, errors, legacy_non_dce_params, self.inputs) + return params, errors + + def completed_jobs( + self, trans, use_cached_job: bool, all_params: List[ToolStateJobInstancePopulatedT] + ) -> Dict[int, Optional[model.Job]]: + completed_jobs: Dict[int, Optional[model.Job]] = {} + for i, param in enumerate(all_params): + if use_cached_job: + tool_id = self.id + assert tool_id + param_dump: ToolStateDumpedToJsonInternalT = params_to_json_internal(self.inputs, param, self.app) + completed_jobs[i] = self.job_search.by_tool_input( + trans=trans, + tool_id=tool_id, + tool_version=self.version, + param=param, + param_dump=param_dump, + job_state=None, + ) + else: + completed_jobs[i] = None + return completed_jobs + def handle_input( self, trans, - incoming: ToolParameterRequestT, + incoming: ToolRequestT, history: Optional[model.History] = None, use_cached_job: bool = DEFAULT_USE_CACHED_JOB, preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID, - input_format: str = "legacy", + input_format: InputFormatT = "legacy", ): """ Process incoming parameters for this tool from the dict `incoming`, @@ -1887,26 +1933,17 @@ def handle_input( there were no errors). """ request_context = proxy_work_context_for_history(trans, history=history) - all_params, all_errors, rerun_remap_job_id, collection_info = self.expand_incoming( - trans=trans, incoming=incoming, request_context=request_context, input_format=input_format - ) + expanded = self.expand_incoming(request_context, incoming=incoming, input_format=input_format) + all_params: List[ToolStateJobInstancePopulatedT] = expanded[0] + all_errors: List[ParameterValidationErrorsT] = expanded[1] + rerun_remap_job_id: Optional[int] = expanded[2] + collection_info: Optional[MatchingCollections] = expanded[3] + # If there were errors, we stay on the same page and display them self.handle_incoming_errors(all_errors) mapping_params = MappingParameters(incoming, all_params) - completed_jobs: Dict[int, Optional[model.Job]] = {} - for i, param in enumerate(all_params): - if use_cached_job: - completed_jobs[i] = self.job_search.by_tool_input( - trans=trans, - tool_id=self.id, - tool_version=self.version, - param=param, - param_dump=self.params_to_strings(param, self.app, nested=True), - job_state=None, - ) - else: - completed_jobs[i] = None + completed_jobs: Dict[int, Optional[model.Job]] = self.completed_jobs(trans, use_cached_job, all_params) execution_tracker = execute_job( trans, self, @@ -1935,7 +1972,7 @@ def handle_input( implicit_collections=execution_tracker.implicit_collections, ) - def handle_incoming_errors(self, all_errors): + def handle_incoming_errors(self, all_errors: List[ParameterValidationErrorsT]) -> None: if any(all_errors): # simple param_key -> message string for tool form. err_data = {key: unicodify(value) for d in all_errors for (key, value) in d.items()} @@ -2060,7 +2097,7 @@ def get_static_param_values(self, trans): def execute( self, trans, - incoming: Optional[ToolParameterRequestInstanceT] = None, + incoming: Optional[ToolStateJobInstancePopulatedT] = None, history: Optional[model.History] = None, set_output_hid: bool = DEFAULT_SET_OUTPUT_HID, flush_job: bool = True, @@ -2086,7 +2123,7 @@ def execute( def _execute( self, trans, - incoming: Optional[ToolParameterRequestInstanceT] = None, + incoming: Optional[ToolStateJobInstancePopulatedT] = None, history: Optional[model.History] = None, rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID, execution_cache: Optional[ToolExecutionCache] = None, @@ -2128,7 +2165,7 @@ def _execute( log.error("Tool execution failed for job: %s", job_id) raise - def params_to_strings(self, params, app, nested=False): + def params_to_strings(self, params: ToolStateJobInstancePopulatedT, app, nested=False): return params_to_strings(self.inputs, params, app, nested) def params_from_strings(self, params, app, ignore_errors=False): @@ -2559,7 +2596,7 @@ def to_json(self, trans, kwd=None, job=None, workflow_building_mode=False, histo set_dataset_matcher_factory(request_context, self) # create tool state state_inputs: Dict[str, str] = {} - state_errors: Dict[str, str] = {} + state_errors: ParameterValidationErrorsT = {} populate_state(request_context, self.inputs, params.__dict__, state_inputs, state_errors) # create tool model @@ -2581,6 +2618,8 @@ def to_json(self, trans, kwd=None, job=None, workflow_building_mode=False, histo else: action = self.app.url_for(self.action) + state_inputs_json: ToolStateDumpedToJsonT = params_to_json(self.inputs, state_inputs, self.app) + # update tool model tool_model.update( { @@ -2594,7 +2633,7 @@ def to_json(self, trans, kwd=None, job=None, workflow_building_mode=False, histo "requirements": [{"name": r.name, "version": r.version} for r in self.requirements], "errors": state_errors, "tool_errors": self.tool_errors, - "state_inputs": params_to_strings(self.inputs, state_inputs, self.app, use_security=True, nested=True), + "state_inputs": state_inputs_json, "job_id": trans.security.encode_id(job.id) if job else None, "job_remap": job.remappable() if job else None, "history_id": trans.security.encode_id(history.id) if history else None, @@ -4110,6 +4149,21 @@ def produce_outputs(self, trans, out_data, output_collections, incoming, history # ---- Utility classes to be factored out ----------------------------------- + + +def _rerun_remap_job_id(trans, incoming, tool_id: Optional[str]) -> Optional[int]: + rerun_remap_job_id = None + if "rerun_remap_job_id" in incoming: + try: + rerun_remap_job_id = trans.app.security.decode_id(incoming["rerun_remap_job_id"]) + except Exception as exception: + log.error(str(exception)) + raise exceptions.MessageException( + "Failure executing tool with id '%s' (attempting to rerun invalid job).", tool_id + ) + return rerun_remap_job_id + + class TracksterConfig: """Trackster configuration encapsulation.""" diff --git a/lib/galaxy/tools/_types.py b/lib/galaxy/tools/_types.py new file mode 100644 index 000000000000..635a86cf459d --- /dev/null +++ b/lib/galaxy/tools/_types.py @@ -0,0 +1,65 @@ +""" +Tool state goes through several different iterations that are difficult to follow I think, +hopefully datatypes can be used as markers to describe what has been done - even if they don't +provide strong traditional typing semantics. + ++--------------------------------+------------+---------------------------------+------------+-----------+ +| Python Type | State for? | Object References | Validated? | xref | ++================================+============+=================================+============+===========+ +| ToolRequestT | request | src dicts of encoded ids | nope | | +| ToolStateJobInstanceT | a job | src dicts of encoded ids | nope | | +| ToolStateJobInstancePopulatedT | a job | model objs loaded from db | check_param | | +| ToolStateDumpedToJsonT | a job | src dicts of encoded ids | " | | +| | | (normalized into values attr) | " | | +| ToolStateDumpedToJsonInternalT | a job | src dicts of decoded ids | " | | +| | | (normalized into values attr) | " | | +| ToolStateDumpedToStringsT | a job | src dicts dumped to strs | " | | +| | | (normalized into values attr) | " | | ++--------------------------------+------------+---------------------------------+-------------+----------+ +""" + +from typing import ( + Any, + Dict, + Union, +) + +from typing_extensions import Literal + +# Input dictionary from the API, may include map/reduce instructions. Objects are referenced by "src" +# dictionaries and encoded IDS. +ToolRequestT = Dict[str, Any] + +# Input dictionary extracted from a tool request for running a tool individually as a single job. Objects are referenced +# by "src" dictionaries with encoded IDs still but batch instructions have been pulled out. Parameters have not +# been "checked" (check_param has not been called). +ToolStateJobInstanceT = Dict[str, Any] + +# Input dictionary for an individual job where objects are their model objects and parameters have been +# "checked" (check_param has been called). +ToolStateJobInstancePopulatedT = Dict[str, Any] + +# Input dictionary for an individual where the state has been valiated and populated but then converted back down +# to json. Object references are unified in the format of {"values": List["src" dictionary]} where the src dictionaries. +# are decoded ids (ints). +# See comments on galaxy.tools.parameters.params_to_strings for more information. +ToolStateDumpedToJsonInternalT = Dict[str, Any] + +# Input dictionary for an individual where the state has been valiated and populated but then converted back down +# to json. Object references are unified in the format of {"values": List["src" dictionary]} where src dictonaries +# are encoded (ids). See comments on galaxy.tools.parameters.params_to_strings for more information. +ToolStateDumpedToJsonT = Dict[str, Any] + +# Input dictionary for an individual where the state has been valiated and populated but then converted back down +# to json. Object references are unified in the format of {"values": List["src" dictionary]} but dumped into +# strings. See comments on galaxy.tools.parameters.params_to_strings for more information. This maybe should be +# broken into separate types for encoded and decoded IDs in subsequent type refinements if both are used, it not +# this comment should be updated to indicate which is used exclusively. +ToolStateDumpedToStringsT = Dict[str, str] + +# A dictionary of error messages that occur while attempting to validate a ToolStateJobInstanceT and transform it +# into a ToolStateJobInstancePopulatedT with model objects populated. Tool errors indicate the job should not be +# further processed. +ParameterValidationErrorsT = Dict[str, Union["ParameterValidationErrorsT", str, Exception]] + +InputFormatT = Literal["legacy", "21.01"] diff --git a/lib/galaxy/tools/actions/__init__.py b/lib/galaxy/tools/actions/__init__.py index 989ab5a6ad45..bddc224e99ae 100644 --- a/lib/galaxy/tools/actions/__init__.py +++ b/lib/galaxy/tools/actions/__init__.py @@ -37,6 +37,7 @@ from galaxy.model.dataset_collections.matching import MatchingCollections from galaxy.model.none_like import NoneDataset from galaxy.objectstore import ObjectStorePopulator +from galaxy.tools._types import ToolStateJobInstancePopulatedT from galaxy.tools.execute import ( DatasetCollectionElementsSliceT, DEFAULT_DATASET_COLLECTION_ELEMENTS, @@ -45,7 +46,6 @@ DEFAULT_RERUN_REMAP_JOB_ID, DEFAULT_SET_OUTPUT_HID, JobCallbackT, - ToolParameterRequestInstanceT, ) from galaxy.tools.execution_helpers import ( filter_output, @@ -88,7 +88,7 @@ def execute( self, tool, trans, - incoming: Optional[ToolParameterRequestInstanceT] = None, + incoming: Optional[ToolStateJobInstancePopulatedT] = None, history: Optional[History] = None, job_params=None, rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID, @@ -104,6 +104,21 @@ def execute( ) -> ToolActionExecuteResult: """Perform target tool action.""" + @abstractmethod + def get_output_name( + self, + output, + dataset=None, + tool=None, + on_text=None, + trans=None, + incoming=None, + history=None, + params=None, + job_params=None, + ) -> str: + """Get name to assign a tool output.""" + class DefaultToolAction(ToolAction): """Default tool action is to run an external command""" @@ -401,7 +416,7 @@ def execute( self, tool, trans, - incoming: Optional[ToolParameterRequestInstanceT] = None, + incoming: Optional[ToolStateJobInstancePopulatedT] = None, history: Optional[History] = None, job_params=None, rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID, @@ -950,7 +965,7 @@ def get_output_name( history=None, params=None, job_params=None, - ): + ) -> str: if output.label: params["tool"] = tool params["on_string"] = on_text diff --git a/lib/galaxy/tools/actions/data_manager.py b/lib/galaxy/tools/actions/data_manager.py index c24e86fd0afb..d8786ad5a921 100644 --- a/lib/galaxy/tools/actions/data_manager.py +++ b/lib/galaxy/tools/actions/data_manager.py @@ -7,6 +7,7 @@ ) from galaxy.model.base import transaction from galaxy.model.dataset_collections.matching import MatchingCollections +from galaxy.tools._types import ToolStateJobInstancePopulatedT from galaxy.tools.execute import ( DatasetCollectionElementsSliceT, DEFAULT_DATASET_COLLECTION_ELEMENTS, @@ -15,7 +16,6 @@ DEFAULT_RERUN_REMAP_JOB_ID, DEFAULT_SET_OUTPUT_HID, JobCallbackT, - ToolParameterRequestInstanceT, ) from galaxy.tools.execution_helpers import ToolExecutionCache from . import ( @@ -33,7 +33,7 @@ def execute( self, tool, trans, - incoming: Optional[ToolParameterRequestInstanceT] = None, + incoming: Optional[ToolStateJobInstancePopulatedT] = None, history: Optional[History] = None, job_params=None, rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID, diff --git a/lib/galaxy/tools/actions/history_imp_exp.py b/lib/galaxy/tools/actions/history_imp_exp.py index 848995c61dac..de502f3a11d7 100644 --- a/lib/galaxy/tools/actions/history_imp_exp.py +++ b/lib/galaxy/tools/actions/history_imp_exp.py @@ -11,6 +11,7 @@ ) from galaxy.model.base import transaction from galaxy.model.dataset_collections.matching import MatchingCollections +from galaxy.tools._types import ToolStateJobInstancePopulatedT from galaxy.tools.actions import ( ToolAction, ToolActionExecuteResult, @@ -23,7 +24,6 @@ DEFAULT_RERUN_REMAP_JOB_ID, DEFAULT_SET_OUTPUT_HID, JobCallbackT, - ToolParameterRequestInstanceT, ) from galaxy.tools.execution_helpers import ToolExecutionCache from galaxy.tools.imp_exp import ( @@ -44,7 +44,7 @@ def execute( self, tool, trans, - incoming: Optional[ToolParameterRequestInstanceT] = None, + incoming: Optional[ToolStateJobInstancePopulatedT] = None, history: Optional[History] = None, job_params=None, rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID, @@ -121,7 +121,7 @@ def execute( self, tool, trans, - incoming: Optional[ToolParameterRequestInstanceT] = None, + incoming: Optional[ToolStateJobInstancePopulatedT] = None, history: Optional[History] = None, job_params=None, rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID, diff --git a/lib/galaxy/tools/actions/metadata.py b/lib/galaxy/tools/actions/metadata.py index f7d6ce844a9d..2b46c6060ccd 100644 --- a/lib/galaxy/tools/actions/metadata.py +++ b/lib/galaxy/tools/actions/metadata.py @@ -16,6 +16,7 @@ ) from galaxy.model.base import transaction from galaxy.model.dataset_collections.matching import MatchingCollections +from galaxy.tools._types import ToolStateJobInstancePopulatedT from galaxy.tools.execute import ( DatasetCollectionElementsSliceT, DEFAULT_DATASET_COLLECTION_ELEMENTS, @@ -24,7 +25,6 @@ DEFAULT_RERUN_REMAP_JOB_ID, DEFAULT_SET_OUTPUT_HID, JobCallbackT, - ToolParameterRequestInstanceT, ) from galaxy.tools.execution_helpers import ToolExecutionCache from galaxy.util import asbool @@ -43,7 +43,7 @@ def execute( self, tool, trans, - incoming: Optional[ToolParameterRequestInstanceT] = None, + incoming: Optional[ToolStateJobInstancePopulatedT] = None, history: Optional[History] = None, job_params=None, rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID, diff --git a/lib/galaxy/tools/actions/model_operations.py b/lib/galaxy/tools/actions/model_operations.py index 1b18adcf39f6..fdfecfb9c65e 100644 --- a/lib/galaxy/tools/actions/model_operations.py +++ b/lib/galaxy/tools/actions/model_operations.py @@ -10,6 +10,7 @@ ) from galaxy.model.dataset_collections.matching import MatchingCollections from galaxy.objectstore import ObjectStorePopulator +from galaxy.tools._types import ToolStateJobInstancePopulatedT from galaxy.tools.actions import ( DefaultToolAction, OutputCollections, @@ -24,7 +25,6 @@ DEFAULT_RERUN_REMAP_JOB_ID, DEFAULT_SET_OUTPUT_HID, JobCallbackT, - ToolParameterRequestInstanceT, ) from galaxy.tools.execution_helpers import ToolExecutionCache @@ -52,7 +52,7 @@ def execute( self, tool, trans, - incoming: Optional[ToolParameterRequestInstanceT] = None, + incoming: Optional[ToolStateJobInstancePopulatedT] = None, history: Optional[History] = None, job_params=None, rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID, diff --git a/lib/galaxy/tools/actions/upload.py b/lib/galaxy/tools/actions/upload.py index b85bba71a0d2..c758c96b2ae7 100644 --- a/lib/galaxy/tools/actions/upload.py +++ b/lib/galaxy/tools/actions/upload.py @@ -11,6 +11,7 @@ from galaxy.model.base import transaction from galaxy.model.dataset_collections.matching import MatchingCollections from galaxy.model.dataset_collections.structure import UninitializedTree +from galaxy.tools._types import ToolStateJobInstancePopulatedT from galaxy.tools.actions import upload_common from galaxy.tools.execute import ( DatasetCollectionElementsSliceT, @@ -20,7 +21,6 @@ DEFAULT_RERUN_REMAP_JOB_ID, DEFAULT_SET_OUTPUT_HID, JobCallbackT, - ToolParameterRequestInstanceT, ) from galaxy.tools.execution_helpers import ToolExecutionCache from galaxy.util import ExecutionTimer @@ -40,7 +40,7 @@ def execute( self, tool, trans, - incoming: Optional[ToolParameterRequestInstanceT] = None, + incoming: Optional[ToolStateJobInstancePopulatedT] = None, history: Optional[History] = None, job_params=None, rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID, diff --git a/lib/galaxy/tools/execute.py b/lib/galaxy/tools/execute.py index aef512f61f37..2ad41a75493b 100644 --- a/lib/galaxy/tools/execute.py +++ b/lib/galaxy/tools/execute.py @@ -15,6 +15,8 @@ List, NamedTuple, Optional, + Tuple, + Union, ) from boltons.iterutils import remap @@ -35,6 +37,10 @@ ToolExecutionCache, ) from galaxy.tools.parameters.workflow_utils import is_runtime_value +from ._types import ( + ToolRequestT, + ToolStateJobInstancePopulatedT, +) if typing.TYPE_CHECKING: from galaxy.tools import Tool @@ -48,10 +54,6 @@ CompletedJobsT = Dict[int, Optional[model.Job]] JobCallbackT = Callable WorkflowResourceParametersT = Dict[str, Any] -# Input dictionary from the API, may include map/reduce instructions -ToolParameterRequestT = Dict[str, Any] -# Input dictionary extracted from a tool request for running a tool individually -ToolParameterRequestInstanceT = Dict[str, Any] DatasetCollectionElementsSliceT = Dict[str, model.DatasetCollectionElement] DEFAULT_USE_CACHED_JOB = False DEFAULT_PREFERRED_OBJECT_STORE_ID: Optional[str] = None @@ -67,8 +69,8 @@ def __init__(self, execution_tracker: "ExecutionTracker"): class MappingParameters(NamedTuple): - param_template: ToolParameterRequestT - param_combinations: List[ToolParameterRequestInstanceT] + param_template: ToolRequestT + param_combinations: List[ToolStateJobInstancePopulatedT] def execute( @@ -242,14 +244,14 @@ def execute_single_job(execution_slice: "ExecutionSlice", completed_job: Optiona class ExecutionSlice: job_index: int - param_combination: ToolParameterRequestInstanceT + param_combination: ToolStateJobInstancePopulatedT dataset_collection_elements: Optional[DatasetCollectionElementsSliceT] history: Optional[model.History] def __init__( self, job_index: int, - param_combination: ToolParameterRequestInstanceT, + param_combination: ToolStateJobInstancePopulatedT, dataset_collection_elements: Optional[DatasetCollectionElementsSliceT] = DEFAULT_DATASET_COLLECTION_ELEMENTS, ): self.job_index = job_index @@ -258,7 +260,16 @@ def __init__( self.history = None +ExecutionErrorsT = Union[str, Exception] + + class ExecutionTracker: + execution_errors: List[ExecutionErrorsT] + successful_jobs: List[model.Job] + output_datasets: List[model.HistoryDatasetAssociation] + output_collections: List[Tuple[str, model.HistoryDatasetCollectionAssociation]] + implicit_collections: Dict[str, model.HistoryDatasetCollectionAssociation] + def __init__( self, trans, @@ -312,8 +323,9 @@ def record_error(self, error): @property def on_text(self): - if self._on_text is None: - collection_names = ["collection %d" % c.hid for c in self.collection_info.collections.values()] + collection_info = self.collection_info + if self._on_text is None and collection_info is not None: + collection_names = ["collection %d" % c.hid for c in collection_info.collections.values()] self._on_text = on_text_for_names(collection_names) return self._on_text @@ -371,19 +383,23 @@ def find_collection(input_dict, input_name): ) subcollection_mapping_type = None if self.is_implicit_input(input_name): - subcollection_mapping_type = self.collection_info.subcollection_mapping_type(input_name) + collection_info = self.collection_info + assert collection_info + subcollection_mapping_type = collection_info.subcollection_mapping_type(input_name) return get_structure( input_collection, collection_type_description, leaf_subcollection_type=subcollection_mapping_type ) def _structure_for_output(self, trans, tool_output): - structure = self.collection_info.structure + collection_info = self.collection_info + assert collection_info + structure = collection_info.structure if hasattr(tool_output, "default_identifier_source"): # Switch the structure for outputs if the output specified a default_identifier_source collection_type_descriptions = trans.app.dataset_collection_manager.collection_type_descriptions - source_collection = self.collection_info.collections.get(tool_output.default_identifier_source) + source_collection = collection_info.collections.get(tool_output.default_identifier_source) if source_collection: collection_type_description = collection_type_descriptions.for_collection_type( source_collection.collection.collection_type @@ -423,10 +439,10 @@ def precreate_output_collections(self, history, params): # collection replaced with a specific dataset. Need to replace this # with the collection and wrap everything up so can evaluate output # label. + collection_info = self.collection_info + assert collection_info trans = self.trans - params.update( - self.collection_info.collections - ) # Replace datasets with source collections for labelling outputs. + params.update(collection_info.collections) # Replace datasets with source collections for labelling outputs. collection_instances = {} implicit_inputs = self.implicit_inputs @@ -498,6 +514,7 @@ def finalize_dataset_collections(self, trans): completed_collections = {} if ( self.completed_jobs + and self.completed_jobs[0] and self.implicit_collection_jobs and len(self.completed_jobs) == len(self.successful_jobs) ): @@ -518,13 +535,16 @@ def finalize_dataset_collections(self, trans): implicit_collection_jobs = implicit_collection.implicit_collection_jobs implicit_collection_jobs.populated_state = "ok" trans.sa_session.add(implicit_collection_jobs) + collection_info = self.collection_info + assert collection_info implicit_collection.collection.finalize( - collection_type_description=self.collection_info.structure.collection_type_description + collection_type_description=collection_info.structure.collection_type_description ) # Mark implicit HDCA as copied - completed_implicit_collection = implicit_collection and completed_collections.get( - implicit_collection.implicit_output_name + implicit_output_name = implicit_collection.implicit_output_name + completed_implicit_collection = ( + implicit_collection and implicit_output_name and completed_collections.get(implicit_output_name) ) if completed_implicit_collection: implicit_collection.copied_from_history_dataset_collection_association_id = ( @@ -538,14 +558,20 @@ def finalize_dataset_collections(self, trans): @property def implicit_inputs(self): - implicit_inputs = list(self.collection_info.collections.items()) + collection_info = self.collection_info + assert collection_info + implicit_inputs = list(collection_info.collections.items()) return implicit_inputs def is_implicit_input(self, input_name): - return input_name in self.collection_info.collections + collection_info = self.collection_info + assert collection_info + return input_name in collection_info.collections def walk_implicit_collections(self): - return self.collection_info.structure.walk_collections(self.implicit_collections) + collection_info = self.collection_info + assert collection_info + return collection_info.structure.walk_collections(self.implicit_collections) def new_execution_slices(self): if self.collection_info is None: @@ -594,7 +620,9 @@ def __init__( # New to track these things for tool output API response in the tool case, # in the workflow case we just write stuff to the database and forget about # it. - self.outputs_by_output_name = collections.defaultdict(list) + self.outputs_by_output_name: Dict[str, List[Union[model.DatasetInstance, model.DatasetCollection]]] = ( + collections.defaultdict(list) + ) def record_success(self, execution_slice, job, outputs): super().record_success(execution_slice, job, outputs) diff --git a/lib/galaxy/tools/parameters/__init__.py b/lib/galaxy/tools/parameters/__init__.py index 7e72bf3a9eb1..a20cef1f9fb0 100644 --- a/lib/galaxy/tools/parameters/__init__.py +++ b/lib/galaxy/tools/parameters/__init__.py @@ -4,7 +4,9 @@ from json import dumps from typing import ( + cast, Dict, + Optional, Union, ) @@ -32,12 +34,23 @@ runtime_to_json, ) from .wrapped import flat_to_nested_state +from .._types import ( + InputFormatT, + ParameterValidationErrorsT, + ToolStateDumpedToJsonInternalT, + ToolStateDumpedToJsonT, + ToolStateDumpedToStringsT, + ToolStateJobInstancePopulatedT, + ToolStateJobInstanceT, +) REPLACE_ON_TRUTHY = object() # Some tools use the code tag and access the code base, expecting certain tool parameters to be available here. __all__ = ("DataCollectionToolParameter", "DataToolParameter", "SelectToolParameter") +ToolInputsT = Dict[str, Union[Group, ToolParameter]] + def visit_input_values( inputs, @@ -253,15 +266,37 @@ def check_param(trans, param, incoming_value, param_values, simple_errors=True): return value, error +def params_to_json_internal( + params: ToolInputsT, param_values: ToolStateJobInstancePopulatedT, app +) -> ToolStateDumpedToJsonInternalT: + """Return ToolStateDumpedToJsonT for supplied validated and populated parameters.""" + return cast( + ToolStateDumpedToJsonInternalT, params_to_strings(params, param_values, app, nested=True, use_security=False) + ) + + +def params_to_json(params: ToolInputsT, param_values: ToolStateJobInstancePopulatedT, app) -> ToolStateDumpedToJsonT: + """Return ToolStateDumpedToJsonT for supplied validated and populated parameters.""" + return cast(ToolStateDumpedToJsonT, params_to_strings(params, param_values, app, nested=True, use_security=True)) + + def params_to_strings( - params: Dict[str, Union[Group, ToolParameter]], param_values: Dict, app, nested=False, use_security=False -) -> Dict: + params: ToolInputsT, + param_values: ToolStateJobInstancePopulatedT, + app, + nested=False, + use_security=False, +) -> Union[ToolStateDumpedToJsonT, ToolStateDumpedToJsonInternalT, ToolStateDumpedToStringsT]: """ Convert a dictionary of parameter values to a dictionary of strings suitable for persisting. The `value_to_basic` method of each parameter is called to convert its value to basic types, the result of which is then json encoded (this allowing complex nested parameters and - such). + such). If `nested` this will remain as a sort of JSON-ifiable dictionary + (ToolStateDumpedToJsonT), otherwise these will dumped into strings of the + JSON (ToolStateDumpedToStringsT). If use_security is False, this will return + object references with decoded (integer) IDs, otherwise they will be encoded + strings. """ rval = {} for key, value in param_values.items(): @@ -344,14 +379,14 @@ def replace_dataset_ids(path, key, value): def populate_state( request_context, - inputs, - incoming, - state, - errors=None, + inputs: ToolInputsT, + incoming: ToolStateJobInstanceT, + state: ToolStateJobInstancePopulatedT, + errors: Optional[ParameterValidationErrorsT] = None, context=None, check=True, simple_errors=True, - input_format="legacy", + input_format: InputFormatT = "legacy", ): """ Populates nested state dict from incoming parameter values. @@ -426,77 +461,84 @@ def populate_state( state[input.name] = input.get_initial_value(request_context, context) group_state = state[input.name] if input.type == "repeat": - if len(incoming[input.name]) > input.max or len(incoming[input.name]) < input.min: - errors[input.name] = "The number of repeat elements is outside the range specified by the tool." + repeat_input = cast(Repeat, input) + if ( + len(incoming[repeat_input.name]) > repeat_input.max + or len(incoming[repeat_input.name]) < repeat_input.min + ): + errors[repeat_input.name] = ( + "The number of repeat elements is outside the range specified by the tool." + ) else: del group_state[:] - for rep in incoming[input.name]: - new_state = {} + for rep in incoming[repeat_input.name]: + new_state: ToolStateJobInstancePopulatedT = {} group_state.append(new_state) - new_errors = {} + repeat_errors: ParameterValidationErrorsT = {} populate_state( request_context, - input.inputs, + repeat_input.inputs, rep, new_state, - new_errors, + repeat_errors, context=context, check=check, simple_errors=simple_errors, input_format=input_format, ) - if new_errors: - errors[input.name] = new_errors + if repeat_errors: + errors[repeat_input.name] = repeat_errors elif input.type == "conditional": - test_param_value = incoming.get(input.name, {}).get(input.test_param.name) + conditional_input = cast(Conditional, input) + test_param = cast(ToolParameter, conditional_input.test_param) + test_param_value = incoming.get(conditional_input.name, {}).get(test_param.name) value, error = ( - check_param( - request_context, input.test_param, test_param_value, context, simple_errors=simple_errors - ) + check_param(request_context, test_param, test_param_value, context, simple_errors=simple_errors) if check else [test_param_value, None] ) if error: - errors[input.test_param.name] = error + errors[test_param.name] = error else: try: - current_case = input.get_current_case(value) - group_state = state[input.name] = {} - new_errors = {} + current_case = conditional_input.get_current_case(value) + group_state = state[conditional_input.name] = {} + cast_errors: ParameterValidationErrorsT = {} populate_state( request_context, - input.cases[current_case].inputs, - incoming.get(input.name), + conditional_input.cases[current_case].inputs, + cast(ToolStateJobInstanceT, incoming.get(conditional_input.name)), group_state, - new_errors, + cast_errors, context=context, check=check, simple_errors=simple_errors, input_format=input_format, ) - if new_errors: - errors[input.name] = new_errors + if cast_errors: + errors[conditional_input.name] = cast_errors group_state["__current_case__"] = current_case except Exception: - errors[input.test_param.name] = "The selected case is unavailable/invalid." - group_state[input.test_param.name] = value + errors[test_param.name] = "The selected case is unavailable/invalid." + group_state[test_param.name] = value elif input.type == "section": - new_errors = {} + section_input = cast(Section, input) + section_errors: ParameterValidationErrorsT = {} populate_state( request_context, - input.inputs, - incoming.get(input.name), + section_input.inputs, + cast(ToolStateJobInstanceT, incoming.get(section_input.name)), group_state, - new_errors, + section_errors, context=context, check=check, simple_errors=simple_errors, input_format=input_format, ) - if new_errors: - errors[input.name] = new_errors + if section_errors: + errors[section_input.name] = section_errors elif input.type == "upload_dataset": raise NotImplementedError @@ -516,7 +558,15 @@ def populate_state( def _populate_state_legacy( - request_context, inputs, incoming, state, errors, prefix="", context=None, check=True, simple_errors=True + request_context, + inputs: ToolInputsT, + incoming: ToolStateJobInstanceT, + state: ToolStateJobInstancePopulatedT, + errors, + prefix="", + context=None, + check=True, + simple_errors=True, ): if context is None: context = flat_to_nested_state(incoming) @@ -527,22 +577,23 @@ def _populate_state_legacy( group_state = state[input.name] group_prefix = f"{key}|" if input.type == "repeat": + repeat_input = cast(Repeat, input) rep_index = 0 del group_state[:] while True: rep_prefix = "%s_%d" % (key, rep_index) - rep_min_default = input.default if input.default > input.min else input.min + rep_min_default = repeat_input.default if repeat_input.default > repeat_input.min else repeat_input.min if ( not any(incoming_key.startswith(rep_prefix) for incoming_key in incoming.keys()) and rep_index >= rep_min_default ): break - if rep_index < input.max: - new_state = {"__index__": rep_index} + if rep_index < repeat_input.max: + new_state: ToolStateJobInstancePopulatedT = {"__index__": rep_index} group_state.append(new_state) _populate_state_legacy( request_context, - input.inputs, + repeat_input.inputs, incoming, new_state, errors, @@ -553,13 +604,21 @@ def _populate_state_legacy( ) rep_index += 1 elif input.type == "conditional": - if input.value_ref and not input.value_ref_in_group: - test_param_key = prefix + input.test_param.name + conditional_input = cast(Conditional, input) + test_param = cast(ToolParameter, conditional_input.test_param) + if conditional_input.value_ref and not conditional_input.value_ref_in_group: + test_param_key = prefix + test_param.name else: - test_param_key = group_prefix + input.test_param.name - test_param_value = incoming.get(test_param_key, group_state.get(input.test_param.name)) + test_param_key = group_prefix + test_param.name + test_param_value = incoming.get(test_param_key, group_state.get(test_param.name)) value, error = ( - check_param(request_context, input.test_param, test_param_value, context, simple_errors=simple_errors) + check_param( + request_context, + test_param, + test_param_value, + context, + simple_errors=simple_errors, + ) if check else [test_param_value, None] ) @@ -567,11 +626,11 @@ def _populate_state_legacy( errors[test_param_key] = error else: try: - current_case = input.get_current_case(value) - group_state = state[input.name] = {} + current_case = conditional_input.get_current_case(value) + group_state = state[conditional_input.name] = cast(ToolStateJobInstancePopulatedT, {}) _populate_state_legacy( request_context, - input.cases[current_case].inputs, + conditional_input.cases[current_case].inputs, incoming, group_state, errors, @@ -583,11 +642,12 @@ def _populate_state_legacy( group_state["__current_case__"] = current_case except Exception: errors[test_param_key] = "The selected case is unavailable/invalid." - group_state[input.test_param.name] = value + group_state[test_param.name] = value elif input.type == "section": + section_input = cast(Section, input) _populate_state_legacy( request_context, - input.inputs, + section_input.inputs, incoming, group_state, errors, @@ -597,20 +657,21 @@ def _populate_state_legacy( simple_errors=simple_errors, ) elif input.type == "upload_dataset": - file_count = input.get_file_count(request_context, context) + dataset_input = cast(UploadDataset, input) + file_count = dataset_input.get_file_count(request_context, context) while len(group_state) > file_count: del group_state[-1] while file_count > len(group_state): - new_state = {"__index__": len(group_state)} - for upload_item in input.inputs.values(): - new_state[upload_item.name] = upload_item.get_initial_value(request_context, context) - group_state.append(new_state) + new_state_upload: ToolStateJobInstancePopulatedT = {"__index__": len(group_state)} + for upload_item in dataset_input.inputs.values(): + new_state_upload[upload_item.name] = upload_item.get_initial_value(request_context, context) + group_state.append(new_state_upload) for rep_index, rep_state in enumerate(group_state): rep_index = rep_state.get("__index__", rep_index) rep_prefix = "%s_%d|" % (key, rep_index) _populate_state_legacy( request_context, - input.inputs, + dataset_input.inputs, incoming, rep_state, errors, diff --git a/lib/galaxy/tools/parameters/basic.py b/lib/galaxy/tools/parameters/basic.py index 5eeba8aaecaa..9669b62771e9 100644 --- a/lib/galaxy/tools/parameters/basic.py +++ b/lib/galaxy/tools/parameters/basic.py @@ -174,6 +174,7 @@ class ToolParameter(UsesDictVisibleKeys): >>> assert sorted(p.to_dict(trans).items()) == [('argument', '--parameter-name'), ('help', ''), ('hidden', False), ('is_dynamic', False), ('label', ''), ('model_class', 'ToolParameter'), ('name', 'parameter_name'), ('optional', False), ('refresh_on_change', False), ('type', 'text'), ('value', None)] """ + name: str dict_collection_visible_keys = ["name", "argument", "type", "label", "help", "refresh_on_change"] def __init__(self, tool, input_source, context=None): diff --git a/lib/galaxy/tools/parameters/grouping.py b/lib/galaxy/tools/parameters/grouping.py index 9962e61cf624..26b4e171c0de 100644 --- a/lib/galaxy/tools/parameters/grouping.py +++ b/lib/galaxy/tools/parameters/grouping.py @@ -6,6 +6,7 @@ import logging import os import unicodedata +from math import inf from typing import ( Any, Callable, @@ -35,6 +36,7 @@ if TYPE_CHECKING: from galaxy.tools import Tool from galaxy.tools.parameter.basic import ToolParameter + from galaxy.tools.parameters import ToolInputsT log = logging.getLogger(__name__) URI_PREFIXES = [ @@ -58,9 +60,10 @@ class Group(UsesDictVisibleKeys): dict_collection_visible_keys = ["name", "type"] type: str + name: str - def __init__(self): - self.name = None + def __init__(self, name: str): + self.name = name @property def visible(self): @@ -94,15 +97,18 @@ def to_dict(self, trans): class Repeat(Group): dict_collection_visible_keys = ["name", "type", "title", "help", "default", "min", "max"] type = "repeat" + inputs: "ToolInputsT" + min: int + max: float - def __init__(self): - Group.__init__(self) + def __init__(self, name: str): + Group.__init__(self, name) self._title = None - self.inputs = None + self.inputs = {} self.help = None self.default = 0 - self.min = None - self.max = None + self.min = 0 + self.max = inf @property def title(self): @@ -186,11 +192,12 @@ def input_to_dict(input): class Section(Group): dict_collection_visible_keys = ["name", "type", "title", "help", "expanded"] type = "section" + inputs: "ToolInputsT" - def __init__(self): - Group.__init__(self) + def __init__(self, name: str): + Group.__init__(self, name) self.title = None - self.inputs = None + self.inputs = {} self.help = None self.expanded = False @@ -266,11 +273,12 @@ class Dataset(Bunch): class UploadDataset(Group): type = "upload_dataset" + inputs: "ToolInputsT" - def __init__(self): - Group.__init__(self) + def __init__(self, name: str): + Group.__init__(self, name) self.title = None - self.inputs = None + self.inputs = {} self.file_type_name = "file_type" self.default_file_type = "txt" self.file_type_to_ext = {"auto": self.default_file_type} @@ -735,12 +743,13 @@ def get_filenames(context): class Conditional(Group): type = "conditional" value_from: Callable[[ExpressionContext, "Conditional", "Tool"], Mapping[str, str]] + cases: List["ConditionalWhen"] - def __init__(self): - Group.__init__(self) + def __init__(self, name: str): + Group.__init__(self, name) self.test_param: Optional[ToolParameter] = None self.cases = [] - self.value_ref = None + self.value_ref: Optional[str] = None self.value_ref_in_group = True # When our test_param is not part of the conditional Group, this is False @property @@ -761,7 +770,7 @@ def get_current_case(self, value): def value_to_basic(self, value, app, use_security=False): if self.test_param is None: raise Exception("Must set 'test_param' attribute to use.") - rval = {} + rval: Dict[str, Any] = {} rval[self.test_param.name] = self.test_param.value_to_basic(value[self.test_param.name], app) current_case = rval["__current_case__"] = self.get_current_case(value[self.test_param.name]) for input in self.cases[current_case].inputs.values(): diff --git a/lib/galaxy/tools/parameters/meta.py b/lib/galaxy/tools/parameters/meta.py index fee9f6d6079e..f2d8ba1a68d1 100644 --- a/lib/galaxy/tools/parameters/meta.py +++ b/lib/galaxy/tools/parameters/meta.py @@ -22,6 +22,10 @@ from galaxy.util import permutations from . import visit_input_values from .wrapped import process_key +from .._types import ( + ToolRequestT, + ToolStateJobInstanceT, +) log = logging.getLogger(__name__) @@ -154,10 +158,10 @@ def is_batch(value): return WorkflowParameterExpansion(param_combinations, params_keys, input_combinations) -ExpandedT = Tuple[List[Dict[str, Any]], Optional[matching.MatchingCollections]] +ExpandedT = Tuple[List[ToolStateJobInstanceT], Optional[matching.MatchingCollections]] -def expand_meta_parameters(trans, tool, incoming) -> ExpandedT: +def expand_meta_parameters(trans, tool, incoming: ToolRequestT) -> ExpandedT: """ Take in a dictionary of raw incoming parameters and expand to a list of expanded incoming parameters (one set of parameters per tool diff --git a/lib/galaxy/webapps/galaxy/api/jobs.py b/lib/galaxy/webapps/galaxy/api/jobs.py index 6aebebe5ec3c..9eb5efb40938 100644 --- a/lib/galaxy/webapps/galaxy/api/jobs.py +++ b/lib/galaxy/webapps/galaxy/api/jobs.py @@ -71,7 +71,7 @@ JobIndexViewEnum, JobsService, ) -from galaxy.work.context import WorkRequestContext +from galaxy.work.context import proxy_work_context_for_history log = logging.getLogger(__name__) @@ -478,10 +478,8 @@ def search( for k, v in payload.__annotations__.items(): if k.startswith("files_") or k.startswith("__files_"): inputs[k] = v - request_context = WorkRequestContext(app=trans.app, user=trans.user, history=trans.history) - all_params, all_errors, _, _ = tool.expand_incoming( - trans=trans, incoming=inputs, request_context=request_context - ) + request_context = proxy_work_context_for_history(trans) + all_params, all_errors, _, _ = tool.expand_incoming(request_context, incoming=inputs) if any(all_errors): return [] params_dump = [tool.params_to_strings(param, trans.app, nested=True) for param in all_params] diff --git a/lib/galaxy/webapps/galaxy/api/workflows.py b/lib/galaxy/webapps/galaxy/api/workflows.py index 5de6e9d9b47c..e63063f871a4 100644 --- a/lib/galaxy/webapps/galaxy/api/workflows.py +++ b/lib/galaxy/webapps/galaxy/api/workflows.py @@ -83,6 +83,7 @@ from galaxy.structured_app import StructuredApp from galaxy.tool_shed.galaxy_install.install_manager import InstallRepositoryManager from galaxy.tools import recommendations +from galaxy.tools._types import ParameterValidationErrorsT from galaxy.tools.parameters import populate_state from galaxy.tools.parameters.workflow_utils import workflow_building_modes from galaxy.web import ( @@ -537,7 +538,7 @@ def build_module(self, trans: GalaxyWebTransaction, payload=None): module = module_factory.from_dict(trans, payload, from_tool_form=True) if "tool_state" not in payload: module_state: Dict[str, Any] = {} - errors: Dict[str, str] = {} + errors: ParameterValidationErrorsT = {} populate_state(trans, module.get_inputs(), inputs, module_state, errors=errors, check=True) module.recover_state(module_state, from_tool_form=True) module.check_and_update_state() diff --git a/lib/galaxy/work/context.py b/lib/galaxy/work/context.py index 025fb7aac920..81db0f7e1c6c 100644 --- a/lib/galaxy/work/context.py +++ b/lib/galaxy/work/context.py @@ -175,7 +175,7 @@ def set_history(self, history): def proxy_work_context_for_history( trans: ProvidesHistoryContext, history: Optional[History] = None, workflow_building_mode=False -): +) -> WorkRequestContext: """Create a WorkContext for supplied context with potentially different history. This provides semi-structured access to a transaction/work context with a supplied target diff --git a/lib/galaxy/workflow/modules.py b/lib/galaxy/workflow/modules.py index 9a35e4f9b07b..c6a6af2f9090 100644 --- a/lib/galaxy/workflow/modules.py +++ b/lib/galaxy/workflow/modules.py @@ -30,6 +30,7 @@ ) from galaxy.job_execution.actions.post import ActionBox from galaxy.model import ( + Job, PostJobAction, Workflow, WorkflowInvocationStep, @@ -1209,8 +1210,7 @@ def get_inputs(self): option[2] = True input_parameter_type.static_options[i] = tuple(option) - parameter_type_cond = Conditional() - parameter_type_cond.name = "parameter_definition" + parameter_type_cond = Conditional("parameter_definition") parameter_type_cond.test_param = input_parameter_type cases = [] @@ -1262,8 +1262,7 @@ def get_inputs(self): input_default_value = ColorToolParameter(None, default_source) optional_value = optional_param(optional) - optional_cond = Conditional() - optional_cond.name = "optional" + optional_cond = Conditional("optional") optional_cond.test_param = optional_value when_this_type = ConditionalWhen() @@ -1276,8 +1275,7 @@ def get_inputs(self): name="specify_default", label="Specify a default value", type="boolean", checked=specify_default_checked ) specify_default = BooleanToolParameter(None, specify_default_source) - specify_default_cond = Conditional() - specify_default_cond.name = "specify_default" + specify_default_cond = Conditional("specify_default") specify_default_cond.test_param = specify_default when_specify_default_true = ConditionalWhen() @@ -1350,9 +1348,8 @@ def get_inputs(self): "selected": restrict_how_value == "staticSuggestions", }, ] - restrictions_cond = Conditional() + restrictions_cond = Conditional("restrictions") restrictions_how = SelectToolParameter(None, restrict_how_source) - restrictions_cond.name = "restrictions" restrictions_cond.test_param = restrictions_how when_restrict_none = ConditionalWhen() @@ -2290,19 +2287,7 @@ def callback(input, prefixed_name: str, **kwargs): param_combinations.append(execution_state.inputs) complete = False - completed_jobs = {} - for i, param in enumerate(param_combinations): - if use_cached_job: - completed_jobs[i] = tool.job_search.by_tool_input( - trans=trans, - tool_id=tool.id, - tool_version=tool.version, - param=param, - param_dump=tool.params_to_strings(param, trans.app, nested=True), - job_state=None, - ) - else: - completed_jobs[i] = None + completed_jobs: Dict[int, Optional[Job]] = tool.completed_jobs(trans, use_cached_job, param_combinations) try: mapping_params = MappingParameters(tool_state.inputs, param_combinations) max_num_jobs = progress.maximum_jobs_to_schedule_or_none diff --git a/test/unit/app/tools/test_evaluation.py b/test/unit/app/tools/test_evaluation.py index 52d25566b638..e571b5fd8898 100644 --- a/test/unit/app/tools/test_evaluation.py +++ b/test/unit/app/tools/test_evaluation.py @@ -57,8 +57,7 @@ def test_simple_evaluation(self): assert command_line == "bwa --thresh=4 --in=/galaxy/files/dataset_1.dat --out=/galaxy/files/dataset_2.dat" def test_repeat_evaluation(self): - repeat = Repeat() - repeat.name = "r" + repeat = Repeat("r") repeat.inputs = {"thresh": self.tool.test_thresh_param()} self.tool.set_params({"r": repeat}) self.job.parameters = [ @@ -85,8 +84,7 @@ def test_conditional_evaluation(self): select_xml = XML("""""") parameter = SelectToolParameter(self.tool, select_xml) - conditional = Conditional() - conditional.name = "c" + conditional = Conditional("c") conditional.test_param = parameter when = ConditionalWhen() when.inputs = {"thresh": self.tool.test_thresh_param()}