Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotation improvements #17601

Merged
merged 5 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lib/galaxy/job_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
)
from typing import (
Any,
cast,
Dict,
List,
NamedTuple,
Optional,
TYPE_CHECKING,
)

from galaxy import util
Expand All @@ -34,6 +36,9 @@
Safety,
)

if TYPE_CHECKING:
from galaxy.job_metrics.instrumenters import InstrumentPlugin

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -72,7 +77,7 @@ class JobMetrics:

def __init__(self, conf_file=None, conf_dict=None, **kwargs):
"""Load :class:`JobInstrumenter` objects from specified configuration file."""
self.plugin_classes = self.__plugins_dict()
self.plugin_classes = cast(Dict[str, "InstrumentPlugin"], self.__plugins_dict())
if conf_file and os.path.exists(conf_file):
self.default_job_instrumenter = JobInstrumenter.from_file(self.plugin_classes, conf_file, **kwargs)
elif conf_dict or conf_dict is None:
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/job_metrics/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class FormattedMetric(NamedTuple):
class JobMetricFormatter:
"""Format job metric key-value pairs for human consumption in Web UI."""

def format(self, key: Any, value: Any) -> FormattedMetric:
return FormattedMetric(str(key), str(value))
def format(self, key: str, value: Any) -> FormattedMetric:
    """Render one job metric as a (key, human-readable value) pair."""
    displayed = str(value)
    return FormattedMetric(key, displayed)


def seconds_to_str(value: int) -> str:
Expand Down
3 changes: 0 additions & 3 deletions lib/galaxy/model/dataset_collections/types/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ class ListDatasetCollectionType(BaseDatasetCollectionType):

collection_type = "list"

def __init__(self):
    # No per-instance state to initialize.
    pass

def generate_elements(self, elements):
for identifier, element in elements.items():
association = DatasetCollectionElement(
Expand Down
3 changes: 0 additions & 3 deletions lib/galaxy/model/dataset_collections/types/paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ class PairedDatasetCollectionType(BaseDatasetCollectionType):

collection_type = "paired"

def __init__(self):
    # No per-instance state to initialize.
    pass

def generate_elements(self, elements):
if forward_dataset := elements.get(FORWARD_IDENTIFIER):
left_association = DatasetCollectionElement(
Expand Down
6 changes: 3 additions & 3 deletions lib/galaxy/tool_util/cwl/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def galactic_job_json(
for Galaxy.
"""

datasets = []
dataset_collections = []
datasets: List[Dict[str, Any]] = []
dataset_collections: List[Dict[str, Any]] = []

def response_to_hda(target: UploadTarget, upload_response: Dict[str, Any]) -> Dict[str, str]:
assert isinstance(upload_response, dict), upload_response
Expand Down Expand Up @@ -277,7 +277,7 @@ def replacement_file(value):

return upload_file(file_path, secondary_files_tar_path, filetype=filetype, **kwd)

def replacement_directory(value):
def replacement_directory(value: Dict[str, Any]) -> Dict[str, Any]:
file_path = value.get("location", None) or value.get("path", None)
if file_path is None:
return value
Expand Down
16 changes: 10 additions & 6 deletions lib/galaxy/tools/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

from packaging.version import Version
from typing_extensions import TypeAlias

from galaxy.model import (
DatasetCollection,
Expand Down Expand Up @@ -605,6 +606,9 @@ def __bool__(self) -> bool:
__nonzero__ = __bool__


DatasetCollectionElementWrapper: TypeAlias = Union["DatasetCollectionWrapper", DatasetFilenameWrapper]


class DatasetCollectionWrapper(ToolParameterValueWrapper, HasDatasets):
name: Optional[str]
collection: DatasetCollection
Expand Down Expand Up @@ -642,15 +646,15 @@ def __init__(
self.collection = collection

elements = collection.elements
element_instances = {}
element_instances: Dict[str, DatasetCollectionElementWrapper] = {}

element_instance_list = []
element_instance_list: List[DatasetCollectionElementWrapper] = []
for dataset_collection_element in elements:
element_object = dataset_collection_element.element_object
element_identifier = dataset_collection_element.element_identifier

if dataset_collection_element.is_collection:
element_wrapper: Union[DatasetCollectionWrapper, DatasetFilenameWrapper] = DatasetCollectionWrapper(
element_wrapper: DatasetCollectionElementWrapper = DatasetCollectionWrapper(
job_working_directory, dataset_collection_element, **kwargs
)
else:
Expand Down Expand Up @@ -757,15 +761,15 @@ def serialize(
def is_input_supplied(self) -> bool:
return self.__input_supplied

def __getitem__(self, key: Union[str, int]) -> Union[None, "DatasetCollectionWrapper", DatasetFilenameWrapper]:
def __getitem__(self, key: Union[str, int]) -> Optional[DatasetCollectionElementWrapper]:
if not self.__input_supplied:
return None
if isinstance(key, int):
return self.__element_instance_list[key]
else:
return self.__element_instances[key]

def __getattr__(self, key: str) -> Union[None, "DatasetCollectionWrapper", DatasetFilenameWrapper]:
def __getattr__(self, key: str) -> Optional[DatasetCollectionElementWrapper]:
if not self.__input_supplied:
return None
try:
Expand All @@ -775,7 +779,7 @@ def __getattr__(self, key: str) -> Union[None, "DatasetCollectionWrapper", Datas

def __iter__(
self,
) -> Iterator[Union["DatasetCollectionWrapper", DatasetFilenameWrapper]]:
) -> Iterator[DatasetCollectionElementWrapper]:
if not self.__input_supplied:
return [].__iter__()
return self.__element_instance_list.__iter__()
Expand Down
33 changes: 24 additions & 9 deletions lib/galaxy/workflow/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ def decode_runtime_state(self, step, runtime_state):
state.decode(runtime_state, Bunch(inputs=self.get_runtime_inputs(step)), self.trans.app)
return state

def execute(self, trans, progress, invocation_step, use_cached_job=False):
def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
"""Execute the given workflow invocation step.

Use the supplied workflow progress object to track outputs, find
Expand Down Expand Up @@ -508,7 +510,7 @@ def get_informal_replacement_parameters(self, step) -> List[str]:

return []

def compute_collection_info(self, progress, step, all_inputs):
def compute_collection_info(self, progress: "WorkflowProgress", step, all_inputs):
"""
Use get_all_inputs (if implemented) to determine collection mapping for execution.
"""
Expand All @@ -526,7 +528,7 @@ def compute_collection_info(self, progress, step, all_inputs):
collection_info.when_values = progress.when_values
return collection_info or progress.subworkflow_collection_info

def _find_collections_to_match(self, progress, step, all_inputs):
def _find_collections_to_match(self, progress: "WorkflowProgress", step, all_inputs) -> matching.CollectionsToMatch:
collections_to_match = matching.CollectionsToMatch()
dataset_collection_type_descriptions = self.trans.app.dataset_collection_manager.collection_type_descriptions

Expand Down Expand Up @@ -756,7 +758,9 @@ def get_post_job_actions(self, incoming):
def get_content_id(self):
return self.trans.security.encode_id(self.subworkflow.id)

def execute(self, trans, progress, invocation_step, use_cached_job=False):
def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
"""Execute the given workflow step in the given workflow invocation.
Use the supplied workflow progress object to track outputs, find
inputs, etc...
Expand Down Expand Up @@ -929,7 +933,9 @@ def get_runtime_state(self):
def get_all_inputs(self, data_only=False, connectable_only=False):
return []

def execute(self, trans, progress, invocation_step, use_cached_job=False):
def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
invocation = invocation_step.workflow_invocation
step = invocation_step.workflow_step
input_value = step.state.inputs["input"]
Expand Down Expand Up @@ -963,6 +969,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
if content:
invocation.add_input(content, step.id)
progress.set_outputs_for_input(invocation_step, step_outputs)
return None

def recover_mapping(self, invocation_step, progress):
progress.set_outputs_for_input(invocation_step, already_persisted=True)
Expand Down Expand Up @@ -1522,7 +1529,9 @@ def get_all_outputs(self, data_only=False):
)
]

def execute(self, trans, progress, invocation_step, use_cached_job=False):
def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
step = invocation_step.workflow_step
input_value = step.state.inputs["input"]
if input_value is None:
Expand All @@ -1535,6 +1544,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
input_value = default_value.get("value", NO_REPLACEMENT)
step_outputs = dict(output=input_value)
progress.set_outputs_for_input(invocation_step, step_outputs)
return None

def step_state_to_tool_state(self, state):
state = safe_loads(state)
Expand Down Expand Up @@ -1666,9 +1676,12 @@ def get_runtime_state(self):
state.inputs = dict()
return state

def execute(self, trans, progress, invocation_step, use_cached_job=False):
def execute(
    self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
    """Mark this pause step's outputs as delayed; nothing is computed here."""
    paused_step = invocation_step.workflow_step
    progress.mark_step_outputs_delayed(paused_step, why="executing pause step")
    return None

def recover_mapping(self, invocation_step, progress):
if invocation_step:
Expand Down Expand Up @@ -2131,7 +2144,9 @@ def decode_runtime_state(self, step, runtime_state):
f"Tool {self.tool_id} missing. Cannot recover runtime state.", tool_id=self.tool_id
)

def execute(self, trans, progress, invocation_step, use_cached_job=False):
def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
invocation = invocation_step.workflow_invocation
step = invocation_step.workflow_step
tool = trans.app.toolbox.get_tool(step.tool_id, tool_version=step.tool_version, tool_uuid=step.tool_uuid)
Expand Down Expand Up @@ -2171,7 +2186,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
found_replacement_keys = set()

# Connect up
def callback(input, prefixed_name, **kwargs):
def callback(input, prefixed_name: str, **kwargs):
input_dict = all_inputs_by_name[prefixed_name]

replacement: Union[model.Dataset, NoReplacement] = NO_REPLACEMENT
Expand Down
15 changes: 10 additions & 5 deletions lib/galaxy/workflow/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def queue_invoke(


class WorkflowInvoker:
progress: "WorkflowProgress"

def __init__(
self,
trans: "WorkRequestContext",
Expand Down Expand Up @@ -348,7 +350,7 @@ class WorkflowProgress:
def __init__(
self,
workflow_invocation: WorkflowInvocation,
inputs_by_step_id: Any,
inputs_by_step_id: Dict[int, Any],
module_injector: ModuleInjector,
param_map: Dict[int, Dict[str, Any]],
jobs_per_scheduling_iteration: int = -1,
Expand Down Expand Up @@ -415,7 +417,7 @@ def remaining_steps(
remaining_steps.append((step, invocation_step))
return remaining_steps

def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[str, Any]) -> Any:
def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[str, Any]):
replacement: Union[
modules.NoReplacement,
model.DatasetCollectionInstance,
Expand Down Expand Up @@ -447,7 +449,7 @@ def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[st
replacement = raw_to_galaxy(trans, step_input.default_value)
return replacement

def replacement_for_connection(self, connection: "WorkflowStepConnection", is_data: bool = True) -> Any:
def replacement_for_connection(self, connection: "WorkflowStepConnection", is_data: bool = True):
output_step_id = connection.output_step.id
output_name = connection.output_name
if output_step_id not in self.outputs:
Expand Down Expand Up @@ -530,7 +532,7 @@ def replacement_for_connection(self, connection: "WorkflowStepConnection", is_da

return replacement

def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput") -> Any:
def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput"):
step = workflow_output.workflow_step
output_name = workflow_output.output_name
step_outputs = self.outputs[step.id]
Expand All @@ -541,7 +543,10 @@ def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput") ->
return step_outputs[output_name]

def set_outputs_for_input(
self, invocation_step: WorkflowInvocationStep, outputs: Any = None, already_persisted: bool = False
self,
invocation_step: WorkflowInvocationStep,
outputs: Optional[Dict[str, Any]] = None,
already_persisted: bool = False,
) -> None:
step = invocation_step.workflow_step

Expand Down
11 changes: 0 additions & 11 deletions test/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,6 @@ def celery_includes():
return ["galaxy.celery.tasks"]


def pytest_collection_finish(session):
    """pytest hook: tear down the shared Galaxy test driver after collection.

    Best-effort cleanup. The driver module may never have been imported for
    this run, which is expected and silently ignored; any other failure is
    reported but must not abort the pytest session.
    """
    try:
        # Imported lazily: DRIVER only exists after test collection has run.
        from .test_config_defaults import DRIVER

        DRIVER.tear_down()
        print("Galaxy test driver shutdown successful")
    except ImportError:
        # Expected when the config-defaults tests were not part of this run.
        pass
    except Exception as e:
        # Previously a bare swallow; surface genuine teardown failures
        # without failing the whole session.
        print(f"Galaxy test driver shutdown failed: {e}")


@pytest.fixture
def temp_file():
with tempfile.NamedTemporaryFile(delete=True, mode="wb") as fh:
Expand Down
Loading