
Typing around tool state.
jmchilton committed Aug 29, 2024
1 parent 7227bd2 commit c0573f6
Showing 13 changed files with 256 additions and 113 deletions.
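
The changes thread tool-state type aliases from `galaxy.tools._types` through the job-search and tool-execution code. The alias definitions themselves are not part of this diff; the sketch below is an assumption about their rough shapes, inferred from how the changed code uses them:

```python
# Hypothetical shapes for the aliases referenced in this diff. The real
# definitions live in lib/galaxy/tools/_types.py, which is not shown here.
from typing import Any, Dict

# State for one job instance as it arrives in a tool request, before validation.
ToolStateJobInstanceT = Dict[str, Any]

# The same state after populate_state() has validated it and replaced
# encoded ID references with model objects (HDAs, collections, ...).
ToolStateJobInstancePopulatedT = Dict[str, Any]

# Populated state dumped back to JSON-safe values; judging by the call
# sites, the "internal" variant likely keeps unencoded database IDs.
ToolStateDumpedToJsonT = Dict[str, Any]
ToolStateDumpedToJsonInternalT = Dict[str, Any]

# Per-parameter validation failures keyed by parameter name.
ParameterValidationErrorsT = Dict[str, Any]
```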
17 changes: 13 additions & 4 deletions lib/galaxy/managers/jobs.py
@@ -41,7 +41,10 @@
Safety,
)
from galaxy.managers.collections import DatasetCollectionManager
from galaxy.managers.context import ProvidesUserContext
from galaxy.managers.context import (
ProvidesHistoryContext,
ProvidesUserContext,
)
from galaxy.managers.datasets import DatasetManager
from galaxy.managers.hdas import HDAManager
from galaxy.managers.lddas import LDDAManager
@@ -73,6 +76,10 @@
ExecutionTimer,
listify,
)
from galaxy.tools._types import (
ToolStateJobInstancePopulatedT,
ToolStateDumpedToJsonInternalT,
)
from galaxy.util.search import (
FilteredTerm,
parse_filters_structured,
@@ -81,6 +88,8 @@

log = logging.getLogger(__name__)

JobStateT = str
JobStatesT = Union[JobStateT, List[JobStateT]]

class JobLock(BaseModel):
active: bool = Field(title="Job lock status", description="If active, jobs will not dispatch")
@@ -311,7 +320,7 @@ def __init__(
self.ldda_manager = ldda_manager
self.decode_id = id_encoding_helper.decode_id

def by_tool_input(self, trans, tool_id, tool_version, param=None, param_dump=None, job_state="ok"):
def by_tool_input(self, trans: ProvidesHistoryContext, tool_id: str, tool_version: Optional[str], param: ToolStateJobInstancePopulatedT, param_dump: ToolStateDumpedToJsonInternalT, job_state: JobStatesT = "ok"):
"""Search for jobs producing same results using the 'inputs' part of a tool POST."""
user = trans.user
input_data = defaultdict(list)
@@ -353,7 +362,7 @@ def populate_input_data_input_id(path, key, value):
)

def __search(
self, tool_id, tool_version, user, input_data, job_state=None, param_dump=None, wildcard_param_dump=None
self, tool_id: str, tool_version: Optional[str], user: model.User, input_data, job_state=None, param_dump=None, wildcard_param_dump=None
):
search_timer = ExecutionTimer()

@@ -462,7 +471,7 @@ def replace_dataset_ids(path, key, value):
log.info("No equivalent jobs found %s", search_timer)
return None

def _build_job_subquery(self, tool_id, user_id, tool_version, job_state, wildcard_param_dump):
def _build_job_subquery(self, tool_id: str, user_id: int, tool_version: Optional[str], job_state, wildcard_param_dump):
"""Build subquery that selects a job with correct job parameters."""
stmt = select(model.Job.id).where(
and_(
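
In `by_tool_input` above, `job_state` is typed as `JobStatesT`: either a single state string or a list of them. The query-building code that consumes it is collapsed out of this hunk, so here is a minimal sketch, assuming the usual normalization pattern (the helper name is illustrative, not Galaxy's):

```python
from typing import List, Union

JobStateT = str
JobStatesT = Union[JobStateT, List[JobStateT]]

def normalize_job_states(job_state: JobStatesT) -> List[JobStateT]:
    """Coerce a single state or a list of states into a plain list."""
    if isinstance(job_state, str):
        return [job_state]
    return list(job_state)

# A filter can then treat both spellings uniformly, e.g. with SQLAlchemy:
# model.Job.state.in_(normalize_job_states(job_state))
assert normalize_job_states("ok") == ["ok"]
assert normalize_job_states(["ok", "error"]) == ["ok", "error"]
```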
180 changes: 113 additions & 67 deletions lib/galaxy/tools/__init__.py
@@ -120,6 +120,8 @@
check_param,
params_from_strings,
params_to_incoming,
params_to_json_internal,
params_to_json,
params_to_strings,
populate_state,
visit_input_values,
@@ -183,7 +185,19 @@
get_tool_shed_url_from_tool_shed_registry,
)
from galaxy.version import VERSION_MAJOR
from galaxy.work.context import proxy_work_context_for_history
from galaxy.work.context import (
proxy_work_context_for_history,
WorkRequestContext,
)
from ._types import (
InputFormatT,
ParameterValidationErrorsT,
ToolRequestT,
ToolStateJobInstancePopulatedT,
ToolStateJobInstanceT,
ToolStateDumpedToJsonT,
ToolStateDumpedToJsonInternalT,
)
from .execute import (
DatasetCollectionElementsSliceT,
DEFAULT_JOB_CALLBACK,
@@ -195,8 +209,6 @@
ExecutionSlice,
JobCallbackT,
MappingParameters,
ToolParameterRequestInstanceT,
ToolParameterRequestT,
)

if TYPE_CHECKING:
@@ -1796,22 +1808,17 @@ def visit_inputs(self, values, callback):
if self.check_values:
visit_input_values(self.inputs, values, callback)

def expand_incoming(self, trans, incoming, request_context, input_format="legacy"):
rerun_remap_job_id = None
if "rerun_remap_job_id" in incoming:
try:
rerun_remap_job_id = trans.app.security.decode_id(incoming["rerun_remap_job_id"])
except Exception as exception:
log.error(str(exception))
raise exceptions.MessageException(
"Failure executing tool with id '%s' (attempting to rerun invalid job).", self.id
)

def expand_incoming(
self, request_context: WorkRequestContext, incoming: ToolRequestT, input_format: InputFormatT = "legacy"
):
rerun_remap_job_id = _rerun_remap_job_id(request_context, incoming, self.id)
set_dataset_matcher_factory(request_context, self)

# Fixed set of input parameters may correspond to any number of jobs.
# Expand these out to individual parameters for given jobs (tool executions).
expanded_incomings, collection_info = expand_meta_parameters(trans, self, incoming)
expanded_incomings: List[ToolStateJobInstanceT]
collection_info: Optional[MatchingCollections]
expanded_incomings, collection_info = expand_meta_parameters(request_context, self, incoming)

# Remapping a single job to many jobs doesn't make sense, so disable
# remap if multi-runs of tools are being used.
@@ -1832,53 +1839,81 @@ def expand_incoming(self, trans, incoming, request_context, input_format="legacy
"Validated and populated state for tool request",
)
all_errors = []
all_params = []
all_params: List[ToolStateJobInstancePopulatedT] = []

for expanded_incoming in expanded_incomings:
params = {}
errors: Dict[str, str] = {}
if self.input_translator:
self.input_translator.translate(expanded_incoming)
if not self.check_values:
# If `self.check_values` is false we don't do any checking or
# processing on input This is used to pass raw values
# through to/from external sites.
params = expanded_incoming
else:
# Update state for all inputs on the current page taking new
# values from `incoming`.
populate_state(
request_context,
self.inputs,
expanded_incoming,
params,
errors,
simple_errors=False,
input_format=input_format,
)
# If the tool provides a `validate_input` hook, call it.
validate_input = self.get_hook("validate_input")
if validate_input:
# hooks are so terrible ... this is specifically for https://github.com/galaxyproject/tools-devteam/blob/main/tool_collections/gops/basecoverage/operation_filter.py
legacy_non_dce_params = {
k: v.hda if isinstance(v, model.DatasetCollectionElement) and v.hda else v
for k, v in params.items()
}
validate_input(request_context, errors, legacy_non_dce_params, self.inputs)
params, errors = self._populate(request_context, expanded_incoming, input_format)
all_errors.append(errors)
all_params.append(params)
unset_dataset_matcher_factory(request_context)

log.info(validation_timer)
return all_params, all_errors, rerun_remap_job_id, collection_info

def _populate(
self, request_context, expanded_incoming: ToolStateJobInstanceT, input_format: InputFormatT
) -> Tuple[ToolStateJobInstancePopulatedT, ParameterValidationErrorsT]:
"""Validate expanded parameters for a job to replace references with model objects.
So convert a ToolStateJobInstanceT to a ToolStateJobInstancePopulatedT.
"""
params: ToolStateJobInstancePopulatedT = {}
errors: ParameterValidationErrorsT = {}
if self.input_translator:
self.input_translator.translate(expanded_incoming)
if not self.check_values:
# If `self.check_values` is false we don't do any checking or
# processing on input. This is used to pass raw values
# through to/from external sites.
params = cast(ToolStateJobInstancePopulatedT, expanded_incoming)
else:
# Update state for all inputs on the current page taking new
# values from `incoming`.
populate_state(
request_context,
self.inputs,
expanded_incoming,
params,
errors,
simple_errors=False,
input_format=input_format,
)
# If the tool provides a `validate_input` hook, call it.
validate_input = self.get_hook("validate_input")
if validate_input:
# hooks are so terrible ... this is specifically for https://github.com/galaxyproject/tools-devteam/blob/main/tool_collections/gops/basecoverage/operation_filter.py
legacy_non_dce_params = {
k: v.hda if isinstance(v, model.DatasetCollectionElement) and v.hda else v
for k, v in params.items()
}
validate_input(request_context, errors, legacy_non_dce_params, self.inputs)
return params, errors

def completed_jobs(self, trans, use_cached_job: bool, all_params: List[ToolStateJobInstancePopulatedT]) -> Dict[int, Optional[model.Job]]:
completed_jobs: Dict[int, Optional[model.Job]] = {}
for i, param in enumerate(all_params):
if use_cached_job:
param_dump: ToolStateDumpedToJsonInternalT = params_to_json_internal(self.inputs, param, self.app)
completed_jobs[i] = self.job_search.by_tool_input(
trans=trans,
tool_id=self.id,
tool_version=self.version,
param=param,
param_dump=param_dump,
job_state=None,
)
else:
completed_jobs[i] = None
return completed_jobs

def handle_input(
self,
trans,
incoming: ToolParameterRequestT,
incoming: ToolRequestT,
history: Optional[model.History] = None,
use_cached_job: bool = DEFAULT_USE_CACHED_JOB,
preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
input_format: str = "legacy",
input_format: InputFormatT = "legacy",
):
"""
Process incoming parameters for this tool from the dict `incoming`,
@@ -1887,26 +1922,20 @@ def handle_input(
there were no errors).
"""
request_context = proxy_work_context_for_history(trans, history=history)
all_params, all_errors, rerun_remap_job_id, collection_info = self.expand_incoming(

expanded = self.expand_incoming(
request_context=request_context, incoming=incoming, input_format=input_format
)
all_params: List[ToolStateJobInstancePopulatedT] = expanded[0]
all_errors: List[ParameterValidationErrorsT] = expanded[1]
rerun_remap_job_id: Optional[int] = expanded[2]
collection_info: Optional[MatchingCollections] = expanded[3]

# If there were errors, we stay on the same page and display them
self.handle_incoming_errors(all_errors)

mapping_params = MappingParameters(incoming, all_params)
completed_jobs: Dict[int, Optional[model.Job]] = {}
for i, param in enumerate(all_params):
if use_cached_job:
completed_jobs[i] = self.job_search.by_tool_input(
trans=trans,
tool_id=self.id,
tool_version=self.version,
param=param,
param_dump=self.params_to_strings(param, self.app, nested=True),
job_state=None,
)
else:
completed_jobs[i] = None
completed_jobs: Dict[int, Optional[model.Job]] = self.completed_jobs(trans, use_cached_job, all_params)
execution_tracker = execute_job(
trans,
self,
@@ -1935,7 +1964,7 @@ def handle_input(
implicit_collections=execution_tracker.implicit_collections,
)

def handle_incoming_errors(self, all_errors):
def handle_incoming_errors(self, all_errors: List[ParameterValidationErrorsT]):
if any(all_errors):
# simple param_key -> message string for tool form.
err_data = {key: unicodify(value) for d in all_errors for (key, value) in d.items()}
Expand Down Expand Up @@ -2060,7 +2089,7 @@ def get_static_param_values(self, trans):
def execute(
self,
trans,
incoming: Optional[ToolParameterRequestInstanceT] = None,
incoming: Optional[ToolStateJobInstancePopulatedT] = None,
history: Optional[model.History] = None,
set_output_hid: bool = DEFAULT_SET_OUTPUT_HID,
flush_job: bool = True,
@@ -2086,7 +2115,7 @@ def _execute(
def _execute(
self,
trans,
incoming: Optional[ToolParameterRequestInstanceT] = None,
incoming: Optional[ToolStateJobInstancePopulatedT] = None,
history: Optional[model.History] = None,
rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
execution_cache: Optional[ToolExecutionCache] = None,
@@ -2128,7 +2157,7 @@ def _execute(
log.error("Tool execution failed for job: %s", job_id)
raise

def params_to_strings(self, params, app, nested=False):
def params_to_strings(self, params: ToolStateJobInstancePopulatedT, app, nested=False):
return params_to_strings(self.inputs, params, app, nested)

def params_from_strings(self, params, app, ignore_errors=False):
@@ -2581,6 +2610,8 @@ def to_json(self, trans, kwd=None, job=None, workflow_building_mode=False, histo
else:
action = self.app.url_for(self.action)

state_inputs: ToolStateDumpedToJsonT = params_to_json(self.inputs, state_inputs, self.app)

# update tool model
tool_model.update(
{
Expand All @@ -2594,7 +2625,7 @@ def to_json(self, trans, kwd=None, job=None, workflow_building_mode=False, histo
"requirements": [{"name": r.name, "version": r.version} for r in self.requirements],
"errors": state_errors,
"tool_errors": self.tool_errors,
"state_inputs": params_to_strings(self.inputs, state_inputs, self.app, use_security=True, nested=True),
"state_inputs": state_inputs,
"job_id": trans.security.encode_id(job.id) if job else None,
"job_remap": job.remappable() if job else None,
"history_id": trans.security.encode_id(history.id) if history else None,
@@ -4110,6 +4141,21 @@ def produce_outputs(self, trans, out_data, output_collections, incoming, history


# ---- Utility classes to be factored out -----------------------------------


def _rerun_remap_job_id(trans, incoming, tool_id: str) -> Optional[int]:
rerun_remap_job_id = None
if "rerun_remap_job_id" in incoming:
try:
rerun_remap_job_id = trans.app.security.decode_id(incoming["rerun_remap_job_id"])
except Exception as exception:
log.error(str(exception))
raise exceptions.MessageException(
"Failure executing tool with id '%s' (attempting to rerun invalid job).", tool_id
)
return rerun_remap_job_id


class TracksterConfig:
"""Trackster configuration encapsulation."""

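
Taken together, the typed flow in this file is: a `ToolRequestT` fans out into per-job `ToolStateJobInstanceT` dicts via `expand_meta_parameters`, `_populate` turns each one into a `ToolStateJobInstancePopulatedT` plus a `ParameterValidationErrorsT`, and `completed_jobs` dumps the populated state with `params_to_json_internal` to look up equivalent finished jobs. A compact sketch of that pipeline with stand-in implementations (only the names mirror the diff; the real logic lives in the methods above):

```python
from typing import Any, Dict, List, Tuple

# Stand-in aliases; the real ones come from galaxy.tools._types.
ToolRequestT = Dict[str, Any]
ToolStateJobInstanceT = Dict[str, Any]
ToolStateJobInstancePopulatedT = Dict[str, Any]
ParameterValidationErrorsT = Dict[str, Any]

def expand(request: ToolRequestT) -> List[ToolStateJobInstanceT]:
    # Stand-in for expand_meta_parameters(): multi-run inputs and
    # collection mapping fan one request out into several job states;
    # a plain request passes through as a single instance.
    return [request]

def populate(
    state: ToolStateJobInstanceT,
) -> Tuple[ToolStateJobInstancePopulatedT, ParameterValidationErrorsT]:
    # Stand-in for Tool._populate(): validation would replace encoded ID
    # references with model objects and collect per-parameter errors.
    return dict(state), {}

all_params: List[ToolStateJobInstancePopulatedT] = []
all_errors: List[ParameterValidationErrorsT] = []
for instance in expand({"input1": "encoded_hda_id"}):
    params, errors = populate(instance)
    all_params.append(params)
    all_errors.append(errors)

# handle_input() raises if any(all_errors) and otherwise hands all_params
# to the execution machinery.
```

The `to_json` hunk above follows the same theme: the tool form's `state_inputs` is now produced by `params_to_json` rather than `params_to_strings`, so the serialized form carries an explicit `ToolStateDumpedToJsonT` type.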
(11 more changed files not shown)
