Skip to content

Commit

Permalink
Type annotation improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
nsoranzo committed Mar 7, 2024
1 parent bc0d075 commit 1406a61
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 42 deletions.
3 changes: 3 additions & 0 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7035,6 +7035,9 @@ def contains_collection(self, collection_id):
return len(results) > 0


HistoryItem: TypeAlias = Union[HistoryDatasetAssociation, HistoryDatasetCollectionAssociation]


class LibraryDatasetCollectionAssociation(Base, DatasetCollectionInstance, RepresentById):
"""Associates a DatasetCollection with a library folder."""

Expand Down
12 changes: 6 additions & 6 deletions lib/galaxy/model/store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@

if TYPE_CHECKING:
from galaxy.managers.workflows import WorkflowContentsManager
from galaxy.model import ImplicitCollectionJobs
from galaxy.model import (
HistoryItem,
ImplicitCollectionJobs,
)
from galaxy.model.tags import GalaxyTagHandlerSession

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -1339,9 +1342,6 @@ def _copied_from_object_key(
return copied_from_object_key


HasHid = Union[model.HistoryDatasetAssociation, model.HistoryDatasetCollectionAssociation]


class ObjectImportTracker:
"""Keep track of new and existing imported objects.
Expand All @@ -1359,8 +1359,8 @@ class ObjectImportTracker:
hda_copied_from_sinks: Dict[ObjectKeyType, ObjectKeyType]
hdca_copied_from_sinks: Dict[ObjectKeyType, ObjectKeyType]
jobs_by_key: Dict[ObjectKeyType, model.Job]
requires_hid: List[HasHid]
copy_hid_for: Dict[HasHid, HasHid]
requires_hid: List["HistoryItem"]
copy_hid_for: Dict["HistoryItem", "HistoryItem"]

def __init__(self) -> None:
self.libraries_by_key = {}
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/tool_util/client/staging.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _fetch_post(self, payload: Dict[str, Any]) -> Dict[str, Any]:
return tool_response

@abc.abstractmethod
def _handle_job(self, job_response):
def _handle_job(self, job_response: Dict[str, Any]):
"""Implementer can decide if to wait for job(s) individually or not here."""

def stage(
Expand Down Expand Up @@ -288,7 +288,7 @@ def _post(self, api_path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
assert response.status_code == 200, response.text
return response.json()

def _handle_job(self, job_response):
def _handle_job(self, job_response: Dict[str, Any]):
self.galaxy_interactor.wait_for_job(job_response["id"])

@property
Expand Down
25 changes: 12 additions & 13 deletions lib/galaxy/tool_util/verify/interactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,24 +385,23 @@ def compare(val, expected):
except KeyError:
raise Exception(f"Failed to verify dataset metadata, metadata key [{key}] was not found.")

def wait_for_job(self, job_id, history_id=None, maxseconds=DEFAULT_TOOL_TEST_WAIT):
def wait_for_job(self, job_id: str, history_id: Optional[str] = None, maxseconds=DEFAULT_TOOL_TEST_WAIT) -> None:
self.wait_for(lambda: self.__job_ready(job_id, history_id), maxseconds=maxseconds)

def wait_for(self, func, what="tool test run", **kwd):
def wait_for(self, func: Callable, what: str = "tool test run", **kwd) -> None:
walltime_exceeded = int(kwd.get("maxseconds", DEFAULT_TOOL_TEST_WAIT))
wait_on(func, what, walltime_exceeded)

def get_job_stdio(self, job_id):
job_stdio = self.__get_job_stdio(job_id).json()
return job_stdio
def get_job_stdio(self, job_id: str) -> Dict[str, Any]:
return self.__get_job_stdio(job_id).json()

def __get_job(self, job_id):
def __get_job(self, job_id: str) -> Response:
return self._get(f"jobs/{job_id}")

def __get_job_stdio(self, job_id):
def __get_job_stdio(self, job_id: str) -> Response:
return self._get(f"jobs/{job_id}?full=true")

def get_history(self, history_name="test_history"):
def get_history(self, history_name: str = "test_history") -> Optional[Dict[str, Any]]:
# Return the most recent non-deleted history matching the provided name
filters = urllib.parse.urlencode({"q": "name", "qv": history_name, "order": "update_time"})
response = self._get(f"histories?{filters}")
Expand Down Expand Up @@ -430,7 +429,7 @@ def test_history(
if cleanup and cleanup_callback is not None:
cleanup_callback(history_id)

def new_history(self, history_name="test_history", publish_history=False):
def new_history(self, history_name: str = "test_history", publish_history: bool = False) -> str:
create_response = self._post("histories", {"name": history_name})
try:
create_response.raise_for_status()
Expand All @@ -441,7 +440,7 @@ def new_history(self, history_name="test_history", publish_history=False):
self.publish_history(history_id)
return history_id

def publish_history(self, history_id):
def publish_history(self, history_id: str) -> None:
response = self._put(f"histories/{history_id}", json.dumps({"published": True}))
response.raise_for_status()

Expand Down Expand Up @@ -710,10 +709,10 @@ def __dictify_outputs(self, datasets_object) -> OutputsDict:
def output_hid(self, output_data):
return output_data["id"]

def delete_history(self, history):
def delete_history(self, history: str) -> None:
self._delete(f"histories/{history}")

def __job_ready(self, job_id, history_id=None):
def __job_ready(self, job_id: str, history_id: Optional[str] = None):
if job_id is None:
raise ValueError("__job_ready passed empty job_id")
try:
Expand Down Expand Up @@ -803,7 +802,7 @@ def __contents(self, history_id):
history_contents_response.raise_for_status()
return history_contents_response.json()

def _state_ready(self, job_id, error_msg):
def _state_ready(self, job_id: str, error_msg: str):
state_str = self.__get_job(job_id).json()["state"]
if state_str == "ok":
return True
Expand Down
6 changes: 3 additions & 3 deletions lib/galaxy/tool_util/verify/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_tools(
verify_kwds = (verify_kwds or {}).copy()
tool_test_start = dt.datetime.now()
history_created = False
test_history = None
test_history: Optional[str] = None
if not history_per_test_case:
if not history_name:
history_name = f"History for {results.suitename}"
Expand All @@ -192,8 +192,8 @@ def test_tools(
if log:
log.info(f"Using existing history with id '{test_history}', last updated: {history['update_time']}")
if not test_history:
history_created = True
test_history = galaxy_interactor.new_history(history_name=history_name, publish_history=publish_history)
history_created = True
if log:
log.info(f"History created with id '{test_history}'")
verify_kwds.update(
Expand Down Expand Up @@ -231,7 +231,7 @@ def test_tools(
log.info(f"Report written to '{destination}'")
log.info(results.info_message())
log.info(f"Total tool test time: {dt.datetime.now() - tool_test_start}")
if history_created and not no_history_cleanup:
if test_history and history_created and not no_history_cleanup:
galaxy_interactor.delete_history(test_history)


Expand Down
39 changes: 21 additions & 18 deletions lib/galaxy/webapps/galaxy/services/history_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
List,
Optional,
Set,
TYPE_CHECKING,
Union,
)

Expand Down Expand Up @@ -127,10 +128,12 @@
ServiceBase,
)

if TYPE_CHECKING:
from galaxy.model import HistoryItem

log = logging.getLogger(__name__)

DatasetDetailsType = Union[Set[DecodedDatabaseIdField], Literal["all"]]
HistoryItemModel = Union[HistoryDatasetAssociation, HistoryDatasetCollectionAssociation]


class HistoryContentsIndexParams(Model):
Expand Down Expand Up @@ -690,7 +693,7 @@ def bulk_operation(
history = self.history_manager.get_mutable(history_id, trans.user, current_history=trans.history)
filters = self.history_contents_filters.parse_query_filters(filter_query_params)
self._validate_bulk_operation_params(payload, trans.user, trans)
contents: List[HistoryItemModel]
contents: List["HistoryItem"]
if payload.items:
contents = self._get_contents_by_item_list(
trans,
Expand Down Expand Up @@ -1328,7 +1331,7 @@ def _validate_bulk_operation_params(

def _apply_bulk_operation(
self,
contents: Iterable[HistoryItemModel],
contents: Iterable["HistoryItem"],
operation: HistoryContentItemOperation,
params: Optional[AnyBulkOperationParams],
trans: ProvidesHistoryContext,
Expand All @@ -1343,7 +1346,7 @@ def _apply_bulk_operation(
def _apply_operation_to_item(
self,
operation: HistoryContentItemOperation,
item: HistoryItemModel,
item: "HistoryItem",
params: Optional[AnyBulkOperationParams],
trans: ProvidesHistoryContext,
) -> Optional[BulkOperationItemError]:
Expand All @@ -1358,8 +1361,8 @@ def _apply_operation_to_item(

def _get_contents_by_item_list(
self, trans, history: History, items: List[HistoryContentItem]
) -> List[HistoryItemModel]:
contents: List[HistoryItemModel] = []
) -> List["HistoryItem"]:
contents: List["HistoryItem"] = []

dataset_items = filter(lambda item: item.history_content_type == HistoryContentType.dataset, items)
datasets_ids = map(lambda dataset: dataset.id, dataset_items)
Expand All @@ -1380,7 +1383,7 @@ def _get_contents_by_item_list(

class ItemOperation(Protocol):
def __call__(
self, item: HistoryItemModel, params: Optional[AnyBulkOperationParams], trans: ProvidesHistoryContext
self, item: "HistoryItem", params: Optional[AnyBulkOperationParams], trans: ProvidesHistoryContext
) -> None: ...


Expand Down Expand Up @@ -1414,24 +1417,24 @@ def __init__(
def apply(
self,
operation: HistoryContentItemOperation,
item: HistoryItemModel,
item: "HistoryItem",
params: Optional[AnyBulkOperationParams],
trans: ProvidesHistoryContext,
):
self._operation_map[operation](item, params, trans)

def _get_item_manager(self, item: HistoryItemModel):
def _get_item_manager(self, item: "HistoryItem"):
if isinstance(item, HistoryDatasetAssociation):
return self.hda_manager
return self.hdca_manager

def _hide(self, item: HistoryItemModel):
def _hide(self, item: "HistoryItem"):
item.visible = False

def _unhide(self, item: HistoryItemModel):
def _unhide(self, item: "HistoryItem"):
item.visible = True

def _delete(self, item: HistoryItemModel, trans: ProvidesHistoryContext):
def _delete(self, item: "HistoryItem", trans: ProvidesHistoryContext):
if isinstance(item, HistoryDatasetCollectionAssociation):
self.dataset_collection_manager.delete(trans, "history", item.id, recursive=True, purge=False)
else:
Expand All @@ -1440,13 +1443,13 @@ def _delete(self, item: HistoryItemModel, trans: ProvidesHistoryContext):
# otherwise the history will wait indefinitely for the items to be deleted
item.update()

def _undelete(self, item: HistoryItemModel):
def _undelete(self, item: "HistoryItem"):
if getattr(item, "purged", False):
raise exceptions.ItemDeletionException("This item has been permanently deleted and cannot be recovered.")
manager = self._get_item_manager(item)
manager.undelete(item, flush=self.flush)

def _purge(self, item: HistoryItemModel, trans: ProvidesHistoryContext):
def _purge(self, item: "HistoryItem", trans: ProvidesHistoryContext):
if getattr(item, "purged", False):
# TODO: remove this `update` when we can properly track the operation results to notify the history
item.update()
Expand All @@ -1456,7 +1459,7 @@ def _purge(self, item: HistoryItemModel, trans: ProvidesHistoryContext):
self.hda_manager.purge(item, flush=True, user=trans.user)

def _change_datatype(
self, item: HistoryItemModel, params: ChangeDatatypeOperationParams, trans: ProvidesHistoryContext
self, item: "HistoryItem", params: ChangeDatatypeOperationParams, trans: ProvidesHistoryContext
):
if isinstance(item, HistoryDatasetAssociation):
wrapped_task = self._change_item_datatype(item, params, trans)
Expand Down Expand Up @@ -1501,15 +1504,15 @@ def _change_item_datatype(
dataset_id=item.id, datatype=params.datatype, task_user_id=getattr(trans.user, "id", None)
)

def _change_dbkey(self, item: HistoryItemModel, params: ChangeDbkeyOperationParams):
def _change_dbkey(self, item: "HistoryItem", params: ChangeDbkeyOperationParams):
if isinstance(item, HistoryDatasetAssociation):
item.set_dbkey(params.dbkey)
elif isinstance(item, HistoryDatasetCollectionAssociation):
for dataset_instance in item.dataset_instances:
dataset_instance.set_dbkey(params.dbkey)

def _add_tags(self, trans: ProvidesUserContext, item: HistoryItemModel, params: TagOperationParams):
def _add_tags(self, trans: ProvidesUserContext, item: "HistoryItem", params: TagOperationParams):
trans.tag_handler.add_tags_from_list(trans.user, item, params.tags, flush=self.flush)

def _remove_tags(self, trans: ProvidesUserContext, item: HistoryItemModel, params: TagOperationParams):
def _remove_tags(self, trans: ProvidesUserContext, item: "HistoryItem", params: TagOperationParams):
trans.tag_handler.remove_tags_from_list(trans.user, item, params.tags, flush=self.flush)

0 comments on commit 1406a61

Please sign in to comment.