Skip to content

Commit

Permalink
Improved typing/code structure around tool actions.
Browse files Browse the repository at this point in the history
  • Loading branch information
jmchilton committed Jul 31, 2024
1 parent fa73985 commit 7f3cd89
Show file tree
Hide file tree
Showing 13 changed files with 383 additions and 77 deletions.
4 changes: 2 additions & 2 deletions lib/galaxy/managers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def set_metadata(self, trans, dataset_assoc, overwrite=False, validate=True):
if overwrite:
self.overwrite_metadata(data)

job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute(
job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute_via_trans(
self.app.datatypes_registry.set_external_metadata_tool,
trans,
incoming={"input1": data, "validate": validate},
Expand Down Expand Up @@ -883,7 +883,7 @@ def deserialize_datatype(self, item, key, val, **context):
assert (
trans
), "Logic error in Galaxy, deserialize_datatype not sent a transaction object"  # TODO: restructure this for stronger typing
job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute(
job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute_via_trans(
self.app.datatypes_registry.set_external_metadata_tool, trans, incoming={"input1": item}, overwrite=False
) # overwrite is False as per existing behavior
trans.app.job_manager.enqueue(job, tool=trans.app.datatypes_registry.set_external_metadata_tool)
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def queue_history_export(

# Run job to do export.
history_exp_tool = trans.app.toolbox.get_tool(export_tool_id)
job, *_ = history_exp_tool.execute(trans, incoming=params, history=history, set_output_hid=True)
job, *_ = history_exp_tool.execute(trans, incoming=params, history=history)
trans.app.job_manager.enqueue(job, tool=history_exp_tool)
return job

Expand Down
63 changes: 55 additions & 8 deletions lib/galaxy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
StoredWorkflow,
)
from galaxy.model.base import transaction
from galaxy.model.dataset_collections.matching import MatchingCollections
from galaxy.tool_shed.util.repository_util import get_installed_repository
from galaxy.tool_shed.util.shed_util_common import set_image_paths
from galaxy.tool_util.deps import (
Expand Down Expand Up @@ -107,13 +108,13 @@
from galaxy.tools.actions import (
DefaultToolAction,
ToolAction,
ToolExecutionCache,
)
from galaxy.tools.actions.data_manager import DataManagerToolAction
from galaxy.tools.actions.data_source import DataSourceToolAction
from galaxy.tools.actions.model_operations import ModelOperationToolAction
from galaxy.tools.cache import ToolDocumentCache
from galaxy.tools.evaluation import global_tool_errors
from galaxy.tools.execution_helpers import ToolExecutionCache
from galaxy.tools.imp_exp import JobImportHistoryArchiveWrapper
from galaxy.tools.parameters import (
check_param,
Expand Down Expand Up @@ -184,13 +185,17 @@
from galaxy.version import VERSION_MAJOR
from galaxy.work.context import proxy_work_context_for_history
from .execute import (
DEFAULT_USE_CACHED_JOB,
DatasetCollectionElementsSliceT,
DEFAULT_JOB_CALLBACK,
DEFAULT_PREFERRED_OBJECT_STORE_ID,
DEFAULT_RERUN_REMAP_JOB_ID,
DEFAULT_USE_CACHED_JOB,
execute as execute_job,
ExecutionSlice,
JobCallbackT,
MappingParameters,
ToolParameterRequestInstanceT,
ToolParameterRequestT,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1868,11 +1873,11 @@ def expand_incoming(self, trans, incoming, request_context, input_format="legacy
def handle_input(
self,
trans,
incoming,
incoming: ToolParameterRequestT,
history: Optional[model.History] = None,
use_cached_job: bool = DEFAULT_USE_CACHED_JOB,
preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
input_format="legacy",
input_format: str = "legacy",
):
"""
Process incoming parameters for this tool from the dict `incoming`,
Expand Down Expand Up @@ -1964,7 +1969,7 @@ def handle_single_execution(
resulting output data or an error message indicating the problem.
"""
try:
rval = self.execute(
rval = self._execute(
trans,
incoming=execution_slice.param_combination,
history=history,
Expand Down Expand Up @@ -2051,18 +2056,58 @@ def get_static_param_values(self, trans):
args[key] = param.get_initial_value(trans, None)
return args

def execute(
    self, trans, incoming: Optional[ToolParameterRequestInstanceT] = None, history: Optional[model.History] = None
):
    """
    Execute the tool using parameter values in `incoming`. This just
    dispatches to the `ToolAction` instance specified by
    `self.tool_action`. In general this will create a `Job` that
    when run will build the tool's outputs, e.g. `DefaultToolAction`.

    _execute has many more options but should be accessed through
    handle_single_execution. The public interface to execute should be
    rarely used and in more specific ways.
    """
    # Thin public wrapper: all advanced knobs (rerun remapping, caching,
    # collection slicing, etc.) keep their defaults in _execute.
    return self._execute(
        trans,
        incoming=incoming,
        history=history,
    )

def _execute(
self,
trans,
incoming: Optional[ToolParameterRequestInstanceT] = None,
history: Optional[model.History] = None,
rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
execution_cache: Optional[ToolExecutionCache] = None,
dataset_collection_elements: Optional[DatasetCollectionElementsSliceT] = None,
completed_job: Optional[model.Job] = None,
collection_info: Optional[MatchingCollections] = None,
job_callback: Optional[JobCallbackT] = DEFAULT_JOB_CALLBACK,
preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
flush_job: bool = True,
skip: bool = True,
):
if incoming is None:
incoming = {}
try:
return self.tool_action.execute(
self, trans, incoming=incoming, set_output_hid=set_output_hid, history=history, **kwargs
self,
trans,
incoming=incoming,
history=history,
job_params=None,
rerun_remap_job_id=rerun_remap_job_id,
execution_cache=execution_cache,
dataset_collection_elements=dataset_collection_elements,
completed_job=completed_job,
collection_info=collection_info,
job_callback=job_callback,
preferred_object_store_id=preferred_object_store_id,
flush_job=flush_job,
skip=skip,
)
except exceptions.ToolExecutionError as exc:
job = exc.job
Expand Down Expand Up @@ -2994,7 +3039,9 @@ class SetMetadataTool(Tool):
requires_setting_metadata = False
tool_action: "SetMetadataToolAction"

def regenerate_imported_metadata_if_needed(self, hda, history, user, session_id):
def regenerate_imported_metadata_if_needed(
self, hda: model.HistoryDatasetAssociation, history: model.History, user: model.User, session_id: int
):
if hda.has_metadata_files:
job, *_ = self.tool_action.execute_via_app(
self,
Expand Down
69 changes: 51 additions & 18 deletions lib/galaxy/tools/actions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
cast,
Dict,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
Expand All @@ -24,19 +26,30 @@
from galaxy.job_execution.actions.post import ActionBox
from galaxy.managers.context import ProvidesHistoryContext
from galaxy.model import (
History,
HistoryDatasetAssociation,
Job,
LibraryDatasetDatasetAssociation,
WorkflowRequestInputParameter,
)
from galaxy.model.base import transaction
from galaxy.model.dataset_collections.builder import CollectionBuilder
from galaxy.model.dataset_collections.matching import MatchingCollections
from galaxy.model.none_like import NoneDataset
from galaxy.objectstore import ObjectStorePopulator
from galaxy.tools.execute import (
DatasetCollectionElementsSliceT,
DEFAULT_DATASET_COLLECTION_ELEMENTS,
DEFAULT_JOB_CALLBACK,
DEFAULT_PREFERRED_OBJECT_STORE_ID,
DEFAULT_RERUN_REMAP_JOB_ID,
JobCallbackT,
ToolParameterRequestInstanceT,
)
from galaxy.tools.execution_helpers import (
ToolExecutionCache,
filter_output,
on_text_for_names,
ToolExecutionCache,
)
from galaxy.tools.parameters import update_dataset_ids
from galaxy.tools.parameters.basic import (
Expand All @@ -59,21 +72,42 @@
log = logging.getLogger(__name__)


OutputDatasetsT = Dict[str, DatasetInstance]
ToolActionExecuteResult = Union[Tuple[Job, OutputDatasetsT, History], Tuple[Job, OutputDatasetsT]]


class ToolAction:
    """
    The actions to be taken when a tool is run (after parameters have
    been converted and validated).

    Concrete subclasses (e.g. DefaultToolAction) implement ``execute``
    to build a Job and its outputs from the supplied parameters.
    """

    @abstractmethod
    def execute(
        self,
        tool,
        trans,
        incoming: Optional[ToolParameterRequestInstanceT] = None,
        history: Optional[History] = None,
        job_params=None,
        rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
        execution_cache: Optional[ToolExecutionCache] = None,
        dataset_collection_elements: Optional[DatasetCollectionElementsSliceT] = DEFAULT_DATASET_COLLECTION_ELEMENTS,
        completed_job: Optional[Job] = None,
        collection_info: Optional[MatchingCollections] = None,
        job_callback: Optional[JobCallbackT] = DEFAULT_JOB_CALLBACK,
        preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
        flush_job: bool = True,
        skip: bool = False,
    ) -> ToolActionExecuteResult:
        """Perform target tool action."""


class DefaultToolAction(ToolAction):
"""Default tool action is to run an external command"""

produces_real_jobs = True
produces_real_jobs: bool = True
set_output_hid: bool = True

def _collect_input_datasets(
self,
Expand Down Expand Up @@ -366,21 +400,19 @@ def execute(
self,
tool,
trans,
incoming=None,
return_job=False,
set_output_hid=True,
history=None,
incoming: Optional[ToolParameterRequestInstanceT] = None,
history: Optional[History] = None,
job_params=None,
rerun_remap_job_id=None,
execution_cache=None,
rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
execution_cache: Optional[ToolExecutionCache] = None,
dataset_collection_elements=None,
completed_job=None,
collection_info=None,
job_callback=None,
preferred_object_store_id=None,
flush_job=True,
skip=False,
):
completed_job: Optional[Job] = None,
collection_info: Optional[MatchingCollections] = None,
job_callback: Optional[JobCallbackT] = DEFAULT_JOB_CALLBACK,
preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
flush_job: bool = True,
skip: bool = False,
) -> ToolActionExecuteResult:
"""
Executes a tool, creating job and tool outputs, associating them, and
submitting the job to the job queue. If history is not specified, use
Expand All @@ -401,6 +433,7 @@ def execute(
preserved_hdca_tags,
all_permissions,
) = self._collect_inputs(tool, trans, incoming, history, current_user_roles, collection_info)
assert history # tell type system we've set history and it is no longer optional
# Build name for output datasets based on tool name and input names
on_text = self._get_on_text(inp_data)

Expand Down Expand Up @@ -646,7 +679,7 @@ def handle_output(name, output, hidden=None):
if name not in incoming and name not in child_dataset_names:
# don't add already existing datasets, i.e. async created
history.stage_addition(data)
history.add_pending_items(set_output_hid=set_output_hid)
history.add_pending_items(set_output_hid=self.set_output_hid)

log.info(add_datasets_timer)
job_setup_timer = ExecutionTimer()
Expand Down
56 changes: 53 additions & 3 deletions lib/galaxy/tools/actions/data_manager.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,66 @@
import logging
from typing import Optional

from galaxy.model.base import transaction
from . import DefaultToolAction
from galaxy.model.dataset_collections.matching import MatchingCollections
from galaxy.model import (
History,
Job,
)
from galaxy.tools.execute import (
DatasetCollectionElementsSliceT,
DEFAULT_DATASET_COLLECTION_ELEMENTS,
DEFAULT_JOB_CALLBACK,
DEFAULT_PREFERRED_OBJECT_STORE_ID,
DEFAULT_RERUN_REMAP_JOB_ID,
JobCallbackT,
ToolParameterRequestInstanceT,
)
from galaxy.tools.execution_helpers import ToolExecutionCache
from . import (
DefaultToolAction,
ToolActionExecuteResult,
)

log = logging.getLogger(__name__)


class DataManagerToolAction(DefaultToolAction):
"""Tool action used for Data Manager Tools"""

def execute(self, tool, trans, **kwds):
rval = super().execute(tool, trans, **kwds)
def execute(
self,
tool,
trans,
incoming: Optional[ToolParameterRequestInstanceT] = None,
history: Optional[History] = None,
job_params=None,
rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
execution_cache: Optional[ToolExecutionCache] = None,
dataset_collection_elements: Optional[DatasetCollectionElementsSliceT] = DEFAULT_DATASET_COLLECTION_ELEMENTS,
completed_job: Optional[Job] = None,
collection_info: Optional[MatchingCollections] = None,
job_callback: Optional[JobCallbackT] = DEFAULT_JOB_CALLBACK,
preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
flush_job: bool = True,
skip: bool = False,
) -> ToolActionExecuteResult:
rval = super().execute(
tool,
trans,
incoming=incoming,
history=history,
job_params=job_params,
rerun_remap_job_id=rerun_remap_job_id,
execution_cache=execution_cache,
dataset_collection_elements=dataset_collection_elements,
completed_job=completed_job,
collection_info=collection_info,
job_callback=job_callback,
preferred_object_store_id=preferred_object_store_id,
flush_job=flush_job,
skip=skip,
)
if isinstance(rval, tuple) and len(rval) >= 2 and isinstance(rval[0], trans.app.model.Job):
assoc = trans.app.model.DataManagerJobAssociation(job=rval[0], data_manager_id=tool.data_manager_id)
trans.sa_session.add(assoc)
Expand Down
Loading

0 comments on commit 7f3cd89

Please sign in to comment.