From cb950f54332bcb8bdc1806805fdae7a1c2a199ea Mon Sep 17 00:00:00 2001 From: Nicola Soranzo Date: Wed, 11 Sep 2024 17:47:49 +0100 Subject: [PATCH] Fix signature override of `Tool.exec_after_process()` method Fix the following mypy 1.11 error: ``` lib/galaxy/tools/__init__.py:3212: error: Signature of "exec_after_process" incompatible with supertype "Tool" [override] def exec_after_process(self, app, inp_data, out_data, param_dict, ... ^ lib/galaxy/tools/__init__.py:3212: note: Superclass: lib/galaxy/tools/__init__.py:3212: note: def exec_after_process(self, app: Any, inp_data: Any, out_data: Any, param_dict: Any, job: Any = ..., **kwds: Any) -> Any lib/galaxy/tools/__init__.py:3212: note: Subclass: lib/galaxy/tools/__init__.py:3212: note: def exec_after_process(self, app: Any, inp_data: Any, out_data: Any, param_dict: Any, job: Any, final_job_state: Any = ..., **kwds: Any) -> Any ``` Also: - Add some type annotions - Small refactorings --- lib/galaxy/jobs/__init__.py | 23 ++++++++++++++++++----- lib/galaxy/jobs/command_factory.py | 1 + lib/galaxy/jobs/runners/__init__.py | 6 ++++-- lib/galaxy/tools/__init__.py | 20 +++++++++----------- 4 files changed, 32 insertions(+), 18 deletions(-) diff --git a/lib/galaxy/jobs/__init__.py b/lib/galaxy/jobs/__init__.py index f58fa84b95ad..fbdfe78c706b 100644 --- a/lib/galaxy/jobs/__init__.py +++ b/lib/galaxy/jobs/__init__.py @@ -20,6 +20,7 @@ Dict, Iterable, List, + Optional, TYPE_CHECKING, ) @@ -99,6 +100,7 @@ if TYPE_CHECKING: from galaxy.jobs.handler import JobHandlerQueue + from galaxy.tools import Tool log = logging.getLogger(__name__) @@ -984,11 +986,17 @@ class MinimalJobWrapper(HasResourceParameters): is_task = False - def __init__(self, job: model.Job, app: MinimalManagerApp, use_persisted_destination: bool = False, tool=None): + def __init__( + self, + job: model.Job, + app: MinimalManagerApp, + use_persisted_destination: bool = False, + tool: Optional["Tool"] = None, + ): self.job_id = job.id self.session_id = job.session_id self.user_id = job.user_id - self.app: MinimalManagerApp = app + self.app = app self.tool = tool self.sa_session = self.app.model.context self.extra_filenames: List[str] = [] @@ -2531,10 +2539,15 @@ def set_container(self, container): class JobWrapper(MinimalJobWrapper): - def __init__(self, job, queue: "JobHandlerQueue", use_persisted_destination=False, app=None): - super().__init__(job, app=queue.app, use_persisted_destination=use_persisted_destination) + def __init__(self, job, queue: "JobHandlerQueue", use_persisted_destination=False): + app = queue.app + super().__init__( + job, + app=app, + use_persisted_destination=use_persisted_destination, + tool=app.toolbox.get_tool(job.tool_id, job.tool_version, exact=True), + ) self.queue = queue - self.tool = self.app.toolbox.get_tool(job.tool_id, job.tool_version, exact=True) self.job_runner_mapper = JobRunnerMapper(self, queue.dispatcher.url_to_destination, self.app.job_config) if use_persisted_destination: self.job_runner_mapper.cached_job_destination = JobDestination(from_job=job) diff --git a/lib/galaxy/jobs/command_factory.py b/lib/galaxy/jobs/command_factory.py index 9ff10b079723..3e920fa7f52b 100644 --- a/lib/galaxy/jobs/command_factory.py +++ b/lib/galaxy/jobs/command_factory.py @@ -142,6 +142,7 @@ def build_command( if job_wrapper.is_cwl_job: # Minimal metadata needed by the relocate script + assert job_wrapper.tool cwl_metadata_params = { "job_metadata": join("working", job_wrapper.tool.provided_metadata_file), "job_id_tag": job_wrapper.get_id_tag(), diff --git a/lib/galaxy/jobs/runners/__init__.py b/lib/galaxy/jobs/runners/__init__.py index bde5baedeb08..511c431ac8e5 100644 --- a/lib/galaxy/jobs/runners/__init__.py +++ b/lib/galaxy/jobs/runners/__init__.py @@ -502,6 +502,7 @@ def get_job_file(self, job_wrapper: "MinimalJobWrapper", **kwds) -> str: env_setup_commands.append(env_to_statement(env)) command_line = job_wrapper.runner_command_line tmp_dir_creation_statement = job_wrapper.tmp_dir_creation_statement + assert job_wrapper.tool options = dict( tmp_dir_creation_statement=tmp_dir_creation_statement, job_instrumenter=job_instrumenter, @@ -538,13 +539,14 @@ def _find_container( if not compute_job_directory: compute_job_directory = job_wrapper.working_directory + tool = job_wrapper.tool + assert tool if not compute_tool_directory: - compute_tool_directory = job_wrapper.tool.tool_dir + compute_tool_directory = tool.tool_dir if not compute_tmp_directory: compute_tmp_directory = job_wrapper.tmp_directory() - tool = job_wrapper.tool guest_ports = job_wrapper.guest_ports tool_info = ToolInfo( tool.containers, diff --git a/lib/galaxy/tools/__init__.py b/lib/galaxy/tools/__init__.py index 4e081232dbea..56db376549fd 100644 --- a/lib/galaxy/tools/__init__.py +++ b/lib/galaxy/tools/__init__.py @@ -2346,7 +2346,7 @@ def call_hook(self, hook_name, *args, **kwargs): def exec_before_job(self, app, inp_data, out_data, param_dict=None): pass - def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kwds): + def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state: Optional[str] = None): pass def job_failed(self, job_wrapper, message, exception=False): @@ -2977,7 +2977,7 @@ def exec_before_job(self, app, inp_data, out_data, param_dict=None): with open(expression_inputs_path, "w") as f: json.dump(expression_inputs, f) - def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kwds): + def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None): for key, val in self.outputs.items(): if key not in out_data: # Skip filtered outputs @@ -3151,7 +3151,7 @@ def regenerate_imported_metadata_if_needed( ) self.app.job_manager.enqueue(job=job, tool=self) - def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kwds): + def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None): working_directory = app.object_store.get_filename(job, base_dir="job_work", dir_only=True, obj_dir=True) for name, dataset in inp_data.items(): external_metadata = get_metadata_compute_strategy(app.config, job.id, tool_id=self.id) @@ -3209,8 +3209,8 @@ class ExportHistoryTool(Tool): class ImportHistoryTool(Tool): tool_type = "import_history" - def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None, **kwds): - super().exec_after_process(app, inp_data, out_data, param_dict, job=job, **kwds) + def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None): + super().exec_after_process(app, inp_data, out_data, param_dict, job=job, final_job_state=final_job_state) if final_job_state != DETECTED_JOB_STATE.OK: return JobImportHistoryArchiveWrapper(self.app, job.id).cleanup_after_job() @@ -3234,9 +3234,8 @@ def __remove_interactivetool_by_job(self, job): else: log.warning("Could not determine job to stop InteractiveTool: %s", job) - def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kwds): - # run original exec_after_process - super().exec_after_process(app, inp_data, out_data, param_dict, job=job, **kwds) + def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None): + super().exec_after_process(app, inp_data, out_data, param_dict, job=job, final_job_state=final_job_state) self.__remove_interactivetool_by_job(job) def job_failed(self, job_wrapper, message, exception=False): @@ -3255,12 +3254,11 @@ def __init__(self, config_file, root, app, guid=None, data_manager_id=None, **kw if self.data_manager_id is None: self.data_manager_id = self.id - def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, final_job_state=None, **kwds): + def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None): assert self.allow_user_access(job.user), "You must be an admin to access this tool." if final_job_state != DETECTED_JOB_STATE.OK: return - # run original exec_after_process - super().exec_after_process(app, inp_data, out_data, param_dict, job=job, **kwds) + super().exec_after_process(app, inp_data, out_data, param_dict, job=job, final_job_state=final_job_state) # process results of tool data_manager_id = job.data_manager_association.data_manager_id data_manager = self.app.data_managers.get_manager(data_manager_id)