From 78d44158eb4dc4eeafb4fba4e178b41c1d875701 Mon Sep 17 00:00:00 2001 From: mvdbeek Date: Wed, 1 Nov 2023 11:54:25 +0100 Subject: [PATCH 1/2] add src_id_to_item helper Co-authored-by: Michael R. Crusoe --- lib/galaxy/tools/parameters/basic.py | 108 +++++++++++++-------------- 1 file changed, 54 insertions(+), 54 deletions(-) diff --git a/lib/galaxy/tools/parameters/basic.py b/lib/galaxy/tools/parameters/basic.py index 694fb751c3b5..d3dd190fcc1c 100644 --- a/lib/galaxy/tools/parameters/basic.py +++ b/lib/galaxy/tools/parameters/basic.py @@ -14,6 +14,7 @@ List, Optional, Tuple, + TYPE_CHECKING, Union, ) @@ -51,6 +52,11 @@ from .dataset_matcher import get_dataset_matcher_factory from .sanitize import ToolParameterSanitizer +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from galaxy.security.idencoding import IdEncodingHelper + log = logging.getLogger(__name__) @@ -1944,17 +1950,9 @@ def single_to_json(value): def to_python(self, value, app): def single_to_python(value): if isinstance(value, dict) and "src" in value: - id = value["id"] if isinstance(value["id"], int) else app.security.decode_id(value["id"]) - if value["src"] == "dce": - return session.get(DatasetCollectionElement, id) - elif value["src"] == "hdca": - return session.get(HistoryDatasetCollectionAssociation, id) - elif value["src"] == "ldda": - return session.get(LibraryDatasetDatasetAssociation, id) - else: - return session.get(HistoryDatasetAssociation, id) - - session = app.model.context + if value["src"] not in ("hda", "dce", "ldda", "hdca"): + raise ParameterValueError(f"Invalid value {value}", self.name) + return src_id_to_item(sa_session=app.model.context, security=app.security, value=value) if isinstance(value, dict) and "values" in value: if hasattr(self, "multiple") and self.multiple is True: @@ -1967,18 +1965,22 @@ def single_to_python(value): if value in none_values: return None if isinstance(value, str) and value.find(",") > -1: - return [session.get(HistoryDatasetAssociation, int(v)) for v in value.split(",") if v not in none_values] + return [ + app.model.context.get(HistoryDatasetAssociation, int(v)) + for v in value.split(",") + if v not in none_values + ] elif str(value).startswith("__collection_reduce__|"): decoded_id = str(value)[len("__collection_reduce__|") :] if not decoded_id.isdigit(): decoded_id = app.security.decode_id(decoded_id) - return session.get(HistoryDatasetCollectionAssociation, int(decoded_id)) + return app.model.context.get(HistoryDatasetCollectionAssociation, int(decoded_id)) elif str(value).startswith("dce:"): - return session.get(DatasetCollectionElement, int(value[len("dce:") :])) + return app.model.context.get(DatasetCollectionElement, int(value[len("dce:") :])) elif str(value).startswith("hdca:"): - return session.get(HistoryDatasetCollectionAssociation, int(value[len("hdca:") :])) + return app.model.context.get(HistoryDatasetCollectionAssociation, int(value[len("hdca:") :])) else: - return session.get(HistoryDatasetAssociation, int(value)) + return app.model.context.get(HistoryDatasetAssociation, int(value)) def validate(self, value, trans=None): def do_validate(v): @@ -2029,6 +2031,30 @@ def do_validate(v): raise ValueError("At most %d datasets are required for %s" % (self.max, self.name)) +def src_id_to_item( + sa_session: "Session", value: Dict[str, Any], security: "IdEncodingHelper" +) -> Union[ + DatasetCollectionElement, + HistoryDatasetAssociation, + HistoryDatasetCollectionAssociation, + LibraryDatasetDatasetAssociation, +]: + src_to_class = { + "hda": HistoryDatasetAssociation, + "ldda": LibraryDatasetDatasetAssociation, + "dce": DatasetCollectionElement, + "hdca": HistoryDatasetCollectionAssociation, + } + id_value = value["id"] + decoded_id = id_value if isinstance(id_value, int) else security.decode_id(id_value) + try: + item = sa_session.query(src_to_class[value["src"]]).get(decoded_id) + except KeyError: + raise ValueError(f"Unknown input source {value['src']} passed to job submission API.") + item.extra_params = {k: v for k, v in value.items() if k not in ("src", "id")} + return item + + class DataToolParameter(BaseDataToolParameter): # TODO, Nate: Make sure the following unit tests appropriately test the dataset security # components. Add as many additional tests as necessary. @@ -2105,24 +2131,13 @@ def from_json(self, value, trans, other_values=None): ] ] = [] if isinstance(value, list): - found_hdca = False + found_srcs = set() for single_value in value: if isinstance(single_value, dict) and "src" in single_value and "id" in single_value: - if single_value["src"] == "hda": - decoded_id = trans.security.decode_id(single_value["id"]) - rval.append(session.get(HistoryDatasetAssociation, decoded_id)) - elif single_value["src"] == "hdca": - found_hdca = True - decoded_id = trans.security.decode_id(single_value["id"]) - rval.append(session.get(HistoryDatasetCollectionAssociation, decoded_id)) - elif single_value["src"] == "ldda": - decoded_id = trans.security.decode_id(single_value["id"]) - rval.append(session.get(LibraryDatasetDatasetAssociation, decoded_id)) - elif single_value["src"] == "dce": - decoded_id = trans.security.decode_id(single_value["id"]) - rval.append(session.get(DatasetCollectionElement, decoded_id)) - else: - raise ValueError(f"Unknown input source {single_value['src']} passed to job submission API.") + found_srcs.add(single_value["src"]) + rval.append( + src_id_to_item(sa_session=trans.sa_session, value=single_value, security=trans.security) + ) elif isinstance( single_value, ( @@ -2139,31 +2154,16 @@ def from_json(self, value, trans, other_values=None): # support that for integer column types. log.warning("Encoded ID where unencoded ID expected.") single_value = trans.security.decode_id(single_value) - rval.append(session.get(HistoryDatasetAssociation, single_value)) - if found_hdca: - for val in rval: - if not isinstance(val, HistoryDatasetCollectionAssociation): - raise ParameterValueError( - "if collections are supplied to multiple data input parameter, only collections may be used", - self.name, - ) + rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(single_value)) + if len(found_srcs) > 1 and "hdca" in found_srcs: + raise ParameterValueError( + "if collections are supplied to multiple data input parameter, only collections may be used", + self.name, + ) elif isinstance(value, (HistoryDatasetAssociation, LibraryDatasetDatasetAssociation)): rval.append(value) elif isinstance(value, dict) and "src" in value and "id" in value: - if value["src"] == "ldda": - decoded_id = trans.security.decode_id(value["id"]) - rval.append(trans.sa_session.query(LibraryDatasetDatasetAssociation).get(decoded_id)) - if value["src"] == "hda": - decoded_id = trans.security.decode_id(value["id"]) - rval.append(session.get(HistoryDatasetAssociation, decoded_id)) - elif value["src"] == "hdca": - decoded_id = trans.security.decode_id(value["id"]) - rval.append(session.get(HistoryDatasetCollectionAssociation, decoded_id)) - elif value["src"] == "dce": - decoded_id = trans.security.decode_id(value["id"]) - rval.append(session.get(DatasetCollectionElement, decoded_id)) - else: - raise ValueError(f"Unknown input source {value['src']} passed to job submission API.") + rval.append(src_id_to_item(sa_session=trans.sa_session, value=value, security=trans.security)) elif str(value).startswith("__collection_reduce__|"): encoded_ids = [v[len("__collection_reduce__|") :] for v in str(value).split(",")] decoded_ids = map(trans.security.decode_id, encoded_ids) From 1959e7920e36ab5c0a5504c96c642b8966536edf Mon Sep 17 00:00:00 2001 From: "Michael R. Crusoe" <1330696+mr-c@users.noreply.github.com> Date: Wed, 1 Nov 2023 12:40:10 +0100 Subject: [PATCH 2/2] new sqlalchemy 2.0 compatible syntax Co-authored-by: Marius van den Beek --- lib/galaxy/tools/parameters/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/galaxy/tools/parameters/basic.py b/lib/galaxy/tools/parameters/basic.py index d3dd190fcc1c..b178568c723e 100644 --- a/lib/galaxy/tools/parameters/basic.py +++ b/lib/galaxy/tools/parameters/basic.py @@ -2048,7 +2048,7 @@ def src_id_to_item( id_value = value["id"] decoded_id = id_value if isinstance(id_value, int) else security.decode_id(id_value) try: - item = sa_session.query(src_to_class[value["src"]]).get(decoded_id) + item = sa_session.get(src_to_class[value["src"]], decoded_id) except KeyError: raise ValueError(f"Unknown input source {value['src']} passed to job submission API.") item.extra_params = {k: v for k, v in value.items() if k not in ("src", "id")}