Merge pull request #18626 from jmchilton/tool_execution_typing
Better Typing for Tool Execution Plumbing
mvdbeek authored Aug 6, 2024
2 parents d1e0607 + 3345de1 commit 94611cb
Showing 14 changed files with 519 additions and 157 deletions.
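
The theme of the diff below: the tool-execution call chain gains shared type aliases and default constants, defined once in lib/galaxy/tools/execute.py and imported everywhere else instead of repeating bare literals and untyped arguments. A rough sketch of that pattern — the alias bodies and default values here are assumptions for illustration, not copies of execute.py:

from typing import Any, Callable, Dict, Optional

# Sketch of the aliases/constants named in the hunks below; the definitions
# here are assumptions, the real ones in execute.py may be narrower.
ToolParameterRequestT = Dict[str, Any]          # raw request, may map over inputs
ToolParameterRequestInstanceT = Dict[str, Any]  # one expanded parameter combination
JobCallbackT = Callable                         # per-job callback hook

DEFAULT_SET_OUTPUT_HID: bool = True
DEFAULT_USE_CACHED_JOB: bool = False
DEFAULT_PREFERRED_OBJECT_STORE_ID: Optional[str] = None
DEFAULT_RERUN_REMAP_JOB_ID: Optional[int] = None
DEFAULT_JOB_CALLBACK: Optional[JobCallbackT] = None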
4 changes: 2 additions & 2 deletions lib/galaxy/managers/datasets.py
@@ -530,7 +530,7 @@ def set_metadata(self, trans, dataset_assoc, overwrite=False, validate=True):
if overwrite:
self.overwrite_metadata(data)

- job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute(
+ job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute_via_trans(
self.app.datatypes_registry.set_external_metadata_tool,
trans,
incoming={"input1": data, "validate": validate},
@@ -866,7 +866,7 @@ def deserialize_datatype(self, item, key, val, **context):
assert (
trans
), "Logic error in Galaxy, deserialize_datatype not send a transation object" # TODO: restructure this for stronger typing
- job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute(
+ job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute_via_trans(
self.app.datatypes_registry.set_external_metadata_tool, trans, incoming={"input1": item}, overwrite=False
) # overwrite is False as per existing behavior
trans.app.job_manager.enqueue(job, tool=trans.app.datatypes_registry.set_external_metadata_tool)
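Both call sites in this file move from the generic tool_action.execute to the explicitly named execute_via_trans. A runnable sketch of the naming split, using stand-in classes — the real SetMetadataToolAction also exposes execute_via_app, used by SetMetadataTool further down:

from typing import Any, Dict, Tuple

class Job:  # stand-in for galaxy.model.Job
    pass

class App:  # stand-in for the Galaxy application object
    pass

class Trans:  # stand-in for a request-bound transaction
    def __init__(self, app: App):
        self.app = app

class SetMetadataToolAction:
    """Illustrative only: two entry points named for the context they run in."""

    def execute_via_trans(self, tool, trans: Trans, incoming: Dict[str, Any], overwrite: bool = False) -> Tuple[Job, ...]:
        # Request-bound callers (managers, API) hold a transaction...
        return self.execute_via_app(tool, trans.app, incoming, overwrite=overwrite)

    def execute_via_app(self, tool, app: App, incoming: Dict[str, Any], overwrite: bool = False) -> Tuple[Job, ...]:
        # ...while background callers hold only the app.
        return (Job(),)

job, *_ = SetMetadataToolAction().execute_via_trans(None, Trans(App()), {"input1": object(), "validate": True})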
2 changes: 1 addition & 1 deletion lib/galaxy/managers/histories.py
@@ -418,7 +418,7 @@ def queue_history_export(

# Run job to do export.
history_exp_tool = trans.app.toolbox.get_tool(export_tool_id)
- job, *_ = history_exp_tool.execute(trans, incoming=params, history=history, set_output_hid=True)
+ job, *_ = history_exp_tool.execute(trans, incoming=params, history=history)
trans.app.job_manager.enqueue(job, tool=history_exp_tool)
return job

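The explicit set_output_hid=True can be dropped here because it now matches the typed default (DEFAULT_SET_OUTPUT_HID) on Tool.execute. A minimal sketch of why the call site shrinks, assuming that default value:

from typing import Optional

DEFAULT_SET_OUTPUT_HID: bool = True  # assumed value, matching the dropped keyword

def execute(incoming: Optional[dict] = None, set_output_hid: bool = DEFAULT_SET_OUTPUT_HID) -> bool:
    # Stand-in for Tool.execute: callers relying on the default can omit the keyword.
    return set_output_hid

assert execute({}) == execute({}, set_output_hid=True)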
101 changes: 82 additions & 19 deletions lib/galaxy/tools/__init__.py
@@ -51,6 +51,7 @@
StoredWorkflow,
)
from galaxy.model.base import transaction
+ from galaxy.model.dataset_collections.matching import MatchingCollections
from galaxy.tool_shed.util.repository_util import get_installed_repository
from galaxy.tool_shed.util.shed_util_common import set_image_paths
from galaxy.tool_util.deps import (
@@ -113,6 +114,7 @@
from galaxy.tools.actions.model_operations import ModelOperationToolAction
from galaxy.tools.cache import ToolDocumentCache
from galaxy.tools.evaluation import global_tool_errors
+ from galaxy.tools.execution_helpers import ToolExecutionCache
from galaxy.tools.imp_exp import JobImportHistoryArchiveWrapper
from galaxy.tools.parameters import (
check_param,
@@ -183,8 +185,18 @@
from galaxy.version import VERSION_MAJOR
from galaxy.work.context import proxy_work_context_for_history
from .execute import (
+ DatasetCollectionElementsSliceT,
+ DEFAULT_JOB_CALLBACK,
+ DEFAULT_PREFERRED_OBJECT_STORE_ID,
+ DEFAULT_RERUN_REMAP_JOB_ID,
+ DEFAULT_SET_OUTPUT_HID,
+ DEFAULT_USE_CACHED_JOB,
execute as execute_job,
+ ExecutionSlice,
+ JobCallbackT,
MappingParameters,
+ ToolParameterRequestInstanceT,
+ ToolParameterRequestT,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1862,11 +1874,11 @@ def expand_incoming(self, trans, incoming, request_context, input_format="legacy
def handle_input(
self,
trans,
- incoming,
- history=None,
- use_cached_job=False,
- preferred_object_store_id: Optional[str] = None,
- input_format="legacy",
+ incoming: ToolParameterRequestT,
+ history: Optional[model.History] = None,
+ use_cached_job: bool = DEFAULT_USE_CACHED_JOB,
+ preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
+ input_format: str = "legacy",
):
"""
Process incoming parameters for this tool from the dict `incoming`,
@@ -1942,23 +1954,23 @@ def handle_incoming_errors(self, all_errors):
def handle_single_execution(
self,
trans,
- rerun_remap_job_id,
- execution_slice,
- history,
- execution_cache=None,
- completed_job=None,
- collection_info=None,
- job_callback=None,
- preferred_object_store_id=None,
- flush_job=True,
- skip=False,
+ rerun_remap_job_id: Optional[int],
+ execution_slice: ExecutionSlice,
+ history: model.History,
+ execution_cache: ToolExecutionCache,
+ completed_job: Optional[model.Job],
+ collection_info: Optional[MatchingCollections],
+ job_callback: Optional[JobCallbackT],
+ preferred_object_store_id: Optional[str],
+ flush_job: bool,
+ skip: bool,
):
"""
Return a pair with whether execution is successful as well as either
resulting output data or an error message indicating the problem.
"""
try:
- rval = self.execute(
+ rval = self._execute(
trans,
incoming=execution_slice.param_combination,
history=history,
@@ -2045,18 +2057,67 @@ def get_static_param_values(self, trans):
args[key] = param.get_initial_value(trans, None)
return args

- def execute(self, trans, incoming=None, set_output_hid=True, history=None, **kwargs):
+ def execute(
+ self,
+ trans,
+ incoming: Optional[ToolParameterRequestInstanceT] = None,
+ history: Optional[model.History] = None,
+ set_output_hid: bool = DEFAULT_SET_OUTPUT_HID,
+ flush_job: bool = True,
+ ):
"""
Execute the tool using parameter values in `incoming`. This just
dispatches to the `ToolAction` instance specified by
`self.tool_action`. In general this will create a `Job` that
when run will build the tool's outputs, e.g. `DefaultToolAction`.
+ _execute has many more options but should be accessed through
+ handle_single_execution. The public interface to execute should be
+ rarely used and in more specific ways.
"""
+ return self._execute(
+ trans,
+ incoming=incoming,
+ history=history,
+ set_output_hid=set_output_hid,
+ flush_job=flush_job,
+ )

+ def _execute(
+ self,
+ trans,
+ incoming: Optional[ToolParameterRequestInstanceT] = None,
+ history: Optional[model.History] = None,
+ rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
+ execution_cache: Optional[ToolExecutionCache] = None,
+ dataset_collection_elements: Optional[DatasetCollectionElementsSliceT] = None,
+ completed_job: Optional[model.Job] = None,
+ collection_info: Optional[MatchingCollections] = None,
+ job_callback: Optional[JobCallbackT] = DEFAULT_JOB_CALLBACK,
+ preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
+ set_output_hid: bool = DEFAULT_SET_OUTPUT_HID,
+ flush_job: bool = True,
+ skip: bool = False,
+ ):
+ if incoming is None:
+ incoming = {}
try:
return self.tool_action.execute(
- self, trans, incoming=incoming, set_output_hid=set_output_hid, history=history, **kwargs
+ self,
+ trans,
+ incoming=incoming,
+ history=history,
+ job_params=None,
+ rerun_remap_job_id=rerun_remap_job_id,
+ execution_cache=execution_cache,
+ dataset_collection_elements=dataset_collection_elements,
+ completed_job=completed_job,
+ collection_info=collection_info,
+ job_callback=job_callback,
+ preferred_object_store_id=preferred_object_store_id,
+ set_output_hid=set_output_hid,
+ flush_job=flush_job,
+ skip=skip,
)
except exceptions.ToolExecutionError as exc:
job = exc.job
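
The hunk above splits Tool's small public execute from a private _execute that carries every option handle_single_execution needs, and failures surface as ToolExecutionError with the failed job attached. A condensed, runnable sketch of that shape — anything beyond the names visible in the diff is an assumption:

from typing import Optional

class ToolExecutionError(Exception):
    """Stand-in for galaxy.exceptions.ToolExecutionError; carries the failed job."""
    def __init__(self, message: str, job: Optional[object] = None):
        super().__init__(message)
        self.job = job

class Tool:
    def execute(self, trans, incoming: Optional[dict] = None, history=None,
                set_output_hid: bool = True, flush_job: bool = True):
        # Narrow public surface; the richer options stay on _execute.
        return self._execute(trans, incoming=incoming, history=history,
                             set_output_hid=set_output_hid, flush_job=flush_job)

    def _execute(self, trans, incoming: Optional[dict] = None, history=None,
                 rerun_remap_job_id: Optional[int] = None,
                 preferred_object_store_id: Optional[str] = None,
                 set_output_hid: bool = True, flush_job: bool = True, skip: bool = False):
        incoming = incoming or {}
        # Stand-in for self.tool_action.execute(...); a failure would raise
        # ToolExecutionError, whose .job the caller inspects (as in the except above).
        return ("job", dict(incoming))

job, params = Tool().execute(trans=None, incoming={"input1": 1})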
@@ -2988,7 +3049,9 @@ class SetMetadataTool(Tool):
requires_setting_metadata = False
tool_action: "SetMetadataToolAction"

- def regenerate_imported_metadata_if_needed(self, hda, history, user, session_id):
+ def regenerate_imported_metadata_if_needed(
+ self, hda: model.HistoryDatasetAssociation, history: model.History, user: model.User, session_id: int
+ ):
if hda.has_metadata_files:
job, *_ = self.tool_action.execute_via_app(
self,
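As the docstring above notes, handle_single_execution returns a pair: a success flag plus either resulting output data or an error message. A runnable sketch of how a caller mapping over parameter combinations might consume that contract — the names and result shapes are assumptions, not Galaxy code:

from typing import List, Tuple, Union

# Assumed result shape: (success, outputs-or-error-message).
ExecutionResult = Tuple[bool, Union[dict, str]]

def handle_single_execution(param_combination: dict) -> ExecutionResult:
    # Stand-in: succeed unless a required input is missing.
    if "input1" not in param_combination:
        return False, "input1 is required"
    return True, {"out_file1": f"result for {param_combination['input1']}"}

def execute_mapped(param_combinations: List[dict]):
    outputs, errors = [], []
    for combo in param_combinations:  # one execution slice per expanded combination
        ok, result = handle_single_execution(combo)
        (outputs if ok else errors).append(result)
    return outputs, errors

outs, errs = execute_mapped([{"input1": "a"}, {}])
assert len(outs) == 1 and len(errs) == 1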
132 changes: 55 additions & 77 deletions lib/galaxy/tools/actions/__init__.py
@@ -9,7 +9,9 @@
cast,
Dict,
List,
+ Optional,
Set,
+ Tuple,
TYPE_CHECKING,
Union,
)
@@ -24,15 +26,32 @@
from galaxy.job_execution.actions.post import ActionBox
from galaxy.managers.context import ProvidesHistoryContext
from galaxy.model import (
History,
HistoryDatasetAssociation,
Job,
LibraryDatasetDatasetAssociation,
WorkflowRequestInputParameter,
)
from galaxy.model.base import transaction
from galaxy.model.dataset_collections.builder import CollectionBuilder
from galaxy.model.dataset_collections.matching import MatchingCollections
from galaxy.model.none_like import NoneDataset
from galaxy.objectstore import ObjectStorePopulator
+ from galaxy.tools.execute import (
+ DatasetCollectionElementsSliceT,
+ DEFAULT_DATASET_COLLECTION_ELEMENTS,
+ DEFAULT_JOB_CALLBACK,
+ DEFAULT_PREFERRED_OBJECT_STORE_ID,
+ DEFAULT_RERUN_REMAP_JOB_ID,
+ DEFAULT_SET_OUTPUT_HID,
+ JobCallbackT,
+ ToolParameterRequestInstanceT,
+ )
+ from galaxy.tools.execution_helpers import (
+ filter_output,
+ on_text_for_names,
+ ToolExecutionCache,
+ )
from galaxy.tools.parameters import update_dataset_ids
from galaxy.tools.parameters.basic import (
DataCollectionToolParameter,
@@ -54,32 +73,8 @@
log = logging.getLogger(__name__)


- class ToolExecutionCache:
- """An object meant to cache calculations caused by repeatedly evaluating
- the same tool by the same user with slightly different parameters.
- """
-
- def __init__(self, trans):
- self.trans = trans
- self.current_user_roles = trans.get_current_user_roles()
- self.chrom_info = {}
- self.cached_collection_elements = {}
-
- def get_chrom_info(self, tool_id, input_dbkey):
- genome_builds = self.trans.app.genome_builds
- custom_build_hack_get_len_from_fasta_conversion = tool_id != "CONVERTER_fasta_to_len"
- if custom_build_hack_get_len_from_fasta_conversion and input_dbkey in self.chrom_info:
- return self.chrom_info[input_dbkey]
-
- chrom_info_pair = genome_builds.get_chrom_info(
- input_dbkey,
- trans=self.trans,
- custom_build_hack_get_len_from_fasta_conversion=custom_build_hack_get_len_from_fasta_conversion,
- )
- if custom_build_hack_get_len_from_fasta_conversion:
- self.chrom_info[input_dbkey] = chrom_info_pair
-
- return chrom_info_pair
+ OutputDatasetsT = Dict[str, "DatasetInstance"]
+ ToolActionExecuteResult = Union[Tuple[Job, OutputDatasetsT, Optional[History]], Tuple[Job, OutputDatasetsT]]


class ToolAction:
@@ -89,14 +84,31 @@ class ToolAction:
"""

@abstractmethod
- def execute(self, tool, trans, incoming=None, set_output_hid=True, **kwargs):
- pass
+ def execute(
+ self,
+ tool,
+ trans,
+ incoming: Optional[ToolParameterRequestInstanceT] = None,
+ history: Optional[History] = None,
+ job_params=None,
+ rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
+ execution_cache: Optional[ToolExecutionCache] = None,
+ dataset_collection_elements: Optional[DatasetCollectionElementsSliceT] = DEFAULT_DATASET_COLLECTION_ELEMENTS,
+ completed_job: Optional[Job] = None,
+ collection_info: Optional[MatchingCollections] = None,
+ job_callback: Optional[JobCallbackT] = DEFAULT_JOB_CALLBACK,
+ preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
+ set_output_hid: bool = DEFAULT_SET_OUTPUT_HID,
+ flush_job: bool = True,
+ skip: bool = False,
+ ) -> ToolActionExecuteResult:
+ """Perform target tool action."""


class DefaultToolAction(ToolAction):
"""Default tool action is to run an external command"""

- produces_real_jobs = True
+ produces_real_jobs: bool = True

def _collect_input_datasets(
self,
@@ -389,21 +401,20 @@ def execute(
self,
tool,
trans,
- incoming=None,
- return_job=False,
- set_output_hid=True,
- history=None,
+ incoming: Optional[ToolParameterRequestInstanceT] = None,
+ history: Optional[History] = None,
job_params=None,
- rerun_remap_job_id=None,
- execution_cache=None,
+ rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
+ execution_cache: Optional[ToolExecutionCache] = None,
dataset_collection_elements=None,
- completed_job=None,
- collection_info=None,
- job_callback=None,
- preferred_object_store_id=None,
- flush_job=True,
- skip=False,
- ):
+ completed_job: Optional[Job] = None,
+ collection_info: Optional[MatchingCollections] = None,
+ job_callback: Optional[JobCallbackT] = DEFAULT_JOB_CALLBACK,
+ preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
+ set_output_hid: bool = DEFAULT_SET_OUTPUT_HID,
+ flush_job: bool = True,
+ skip: bool = False,
+ ) -> ToolActionExecuteResult:
"""
Executes a tool, creating job and tool outputs, associating them, and
submitting the job to the job queue. If history is not specified, use
Expand All @@ -424,6 +435,7 @@ def execute(
preserved_hdca_tags,
all_permissions,
) = self._collect_inputs(tool, trans, incoming, history, current_user_roles, collection_info)
+ assert history  # tell type system we've set history and it is no longer optional
# Build name for output datasets based on tool name and input names
on_text = self._get_on_text(inp_data)

@@ -846,7 +858,7 @@ def _get_on_text(self, inp_data):

return on_text_for_names(input_names)

- def _new_job_for_session(self, trans, tool, history):
+ def _new_job_for_session(self, trans, tool, history) -> Tuple[model.Job, Optional[model.GalaxySession]]:
job = trans.app.model.Job()
job.galaxy_version = trans.app.config.version_major
galaxy_session = None
@@ -1097,40 +1109,6 @@ def check_elements(elements):
self.out_collection_instances[name] = hdca


- def on_text_for_names(input_names):
- # input_names may contain duplicates... this is because the first value in
- # multiple input dataset parameters will appear twice once as param_name
- # and once as param_name1.
- unique_names = []
- for name in input_names:
- if name not in unique_names:
- unique_names.append(name)
- input_names = unique_names
-
- # Build name for output datasets based on tool name and input names
- if len(input_names) == 0:
- on_text = ""
- elif len(input_names) == 1:
- on_text = input_names[0]
- elif len(input_names) == 2:
- on_text = "{} and {}".format(*input_names)
- elif len(input_names) == 3:
- on_text = "{}, {}, and {}".format(*input_names)
- else:
- on_text = "{}, {}, and others".format(*input_names[:2])
- return on_text
-
-
- def filter_output(tool, output, incoming):
- for filter in output.filters:
- try:
- if not eval(filter.text.strip(), globals(), incoming):
- return True # do not create this dataset
- except Exception as e:
- log.debug(f"Tool {tool.id} output {output.name}: dataset output filter ({filter.text}) failed: {e}")
- return False


def get_ext_or_implicit_ext(hda):
if hda.implicitly_converted_parent_datasets:
# implicitly_converted_parent_datasets is a list of ImplicitlyConvertedDatasetAssociation
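The new OutputDatasetsT and ToolActionExecuteResult aliases near the top of this file pin down what every ToolAction.execute returns: a Job plus its output datasets, optionally followed by a History. A runnable sketch of that contract with stand-in model classes:

from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple, Union

class Job: ...              # stand-ins for galaxy.model classes
class History: ...
class DatasetInstance: ...

OutputDatasetsT = Dict[str, DatasetInstance]
# Actions may or may not report the (possibly created) history back to the caller.
ToolActionExecuteResult = Union[
    Tuple[Job, OutputDatasetsT, Optional[History]],
    Tuple[Job, OutputDatasetsT],
]

class ToolAction(ABC):
    @abstractmethod
    def execute(self, tool, trans, incoming: Optional[dict] = None,
                history: Optional[History] = None) -> ToolActionExecuteResult:
        """Perform target tool action."""

class NoopToolAction(ToolAction):
    produces_real_jobs: bool = False

    def execute(self, tool, trans, incoming=None, history=None) -> ToolActionExecuteResult:
        # Callers unpack with `job, *_ = ...`, so both tuple shapes work.
        return Job(), {}

job, *_ = NoopToolAction().execute(tool=None, trans=None)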
(The remaining 10 changed files are not shown.)
