Skip to content

Commit

Permalink
Merge pull request galaxyproject#16953 from common-workflow-lab/src_i…
Browse files Browse the repository at this point in the history
…d_to_item

Add helper to get dataset or collection via ``src`` and ``id``
  • Loading branch information
mvdbeek authored Nov 1, 2023
2 parents 50d6e81 + 1959e79 commit af386d2
Showing 1 changed file with 54 additions and 54 deletions.
108 changes: 54 additions & 54 deletions lib/galaxy/tools/parameters/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
List,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)

Expand Down Expand Up @@ -51,6 +52,11 @@
from .dataset_matcher import get_dataset_matcher_factory
from .sanitize import ToolParameterSanitizer

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from galaxy.security.idencoding import IdEncodingHelper

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -1944,17 +1950,9 @@ def single_to_json(value):
def to_python(self, value, app):
def single_to_python(value):
if isinstance(value, dict) and "src" in value:
id = value["id"] if isinstance(value["id"], int) else app.security.decode_id(value["id"])
if value["src"] == "dce":
return session.get(DatasetCollectionElement, id)
elif value["src"] == "hdca":
return session.get(HistoryDatasetCollectionAssociation, id)
elif value["src"] == "ldda":
return session.get(LibraryDatasetDatasetAssociation, id)
else:
return session.get(HistoryDatasetAssociation, id)

session = app.model.context
if value["src"] not in ("hda", "dce", "ldda", "hdca"):
raise ParameterValueError(f"Invalid value {value}", self.name)
return src_id_to_item(sa_session=app.model.context, security=app.security, value=value)

if isinstance(value, dict) and "values" in value:
if hasattr(self, "multiple") and self.multiple is True:
Expand All @@ -1967,18 +1965,22 @@ def single_to_python(value):
if value in none_values:
return None
if isinstance(value, str) and value.find(",") > -1:
return [session.get(HistoryDatasetAssociation, int(v)) for v in value.split(",") if v not in none_values]
return [
app.model.context.get(HistoryDatasetAssociation, int(v))
for v in value.split(",")
if v not in none_values
]
elif str(value).startswith("__collection_reduce__|"):
decoded_id = str(value)[len("__collection_reduce__|") :]
if not decoded_id.isdigit():
decoded_id = app.security.decode_id(decoded_id)
return session.get(HistoryDatasetCollectionAssociation, int(decoded_id))
return app.model.context.get(HistoryDatasetCollectionAssociation, int(decoded_id))
elif str(value).startswith("dce:"):
return session.get(DatasetCollectionElement, int(value[len("dce:") :]))
return app.model.context.get(DatasetCollectionElement, int(value[len("dce:") :]))
elif str(value).startswith("hdca:"):
return session.get(HistoryDatasetCollectionAssociation, int(value[len("hdca:") :]))
return app.model.context.get(HistoryDatasetCollectionAssociation, int(value[len("hdca:") :]))
else:
return session.get(HistoryDatasetAssociation, int(value))
return app.model.context.get(HistoryDatasetAssociation, int(value))

def validate(self, value, trans=None):
def do_validate(v):
Expand Down Expand Up @@ -2029,6 +2031,30 @@ def do_validate(v):
raise ValueError("At most %d datasets are required for %s" % (self.max, self.name))


def src_id_to_item(
sa_session: "Session", value: Dict[str, Any], security: "IdEncodingHelper"
) -> Union[
DatasetCollectionElement,
HistoryDatasetAssociation,
HistoryDatasetCollectionAssociation,
LibraryDatasetDatasetAssociation,
]:
src_to_class = {
"hda": HistoryDatasetAssociation,
"ldda": LibraryDatasetDatasetAssociation,
"dce": DatasetCollectionElement,
"hdca": HistoryDatasetCollectionAssociation,
}
id_value = value["id"]
decoded_id = id_value if isinstance(id_value, int) else security.decode_id(id_value)
try:
item = sa_session.get(src_to_class[value["src"]], decoded_id)
except KeyError:
raise ValueError(f"Unknown input source {value['src']} passed to job submission API.")
item.extra_params = {k: v for k, v in value.items() if k not in ("src", "id")}
return item


class DataToolParameter(BaseDataToolParameter):
# TODO, Nate: Make sure the following unit tests appropriately test the dataset security
# components. Add as many additional tests as necessary.
Expand Down Expand Up @@ -2105,24 +2131,13 @@ def from_json(self, value, trans, other_values=None):
]
] = []
if isinstance(value, list):
found_hdca = False
found_srcs = set()
for single_value in value:
if isinstance(single_value, dict) and "src" in single_value and "id" in single_value:
if single_value["src"] == "hda":
decoded_id = trans.security.decode_id(single_value["id"])
rval.append(session.get(HistoryDatasetAssociation, decoded_id))
elif single_value["src"] == "hdca":
found_hdca = True
decoded_id = trans.security.decode_id(single_value["id"])
rval.append(session.get(HistoryDatasetCollectionAssociation, decoded_id))
elif single_value["src"] == "ldda":
decoded_id = trans.security.decode_id(single_value["id"])
rval.append(session.get(LibraryDatasetDatasetAssociation, decoded_id))
elif single_value["src"] == "dce":
decoded_id = trans.security.decode_id(single_value["id"])
rval.append(session.get(DatasetCollectionElement, decoded_id))
else:
raise ValueError(f"Unknown input source {single_value['src']} passed to job submission API.")
found_srcs.add(single_value["src"])
rval.append(
src_id_to_item(sa_session=trans.sa_session, value=single_value, security=trans.security)
)
elif isinstance(
single_value,
(
Expand All @@ -2139,31 +2154,16 @@ def from_json(self, value, trans, other_values=None):
# support that for integer column types.
log.warning("Encoded ID where unencoded ID expected.")
single_value = trans.security.decode_id(single_value)
rval.append(session.get(HistoryDatasetAssociation, single_value))
if found_hdca:
for val in rval:
if not isinstance(val, HistoryDatasetCollectionAssociation):
raise ParameterValueError(
"if collections are supplied to multiple data input parameter, only collections may be used",
self.name,
)
rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(single_value))
if len(found_srcs) > 1 and "hdca" in found_srcs:
raise ParameterValueError(
"if collections are supplied to multiple data input parameter, only collections may be used",
self.name,
)
elif isinstance(value, (HistoryDatasetAssociation, LibraryDatasetDatasetAssociation)):
rval.append(value)
elif isinstance(value, dict) and "src" in value and "id" in value:
if value["src"] == "ldda":
decoded_id = trans.security.decode_id(value["id"])
rval.append(trans.sa_session.query(LibraryDatasetDatasetAssociation).get(decoded_id))
if value["src"] == "hda":
decoded_id = trans.security.decode_id(value["id"])
rval.append(session.get(HistoryDatasetAssociation, decoded_id))
elif value["src"] == "hdca":
decoded_id = trans.security.decode_id(value["id"])
rval.append(session.get(HistoryDatasetCollectionAssociation, decoded_id))
elif value["src"] == "dce":
decoded_id = trans.security.decode_id(value["id"])
rval.append(session.get(DatasetCollectionElement, decoded_id))
else:
raise ValueError(f"Unknown input source {value['src']} passed to job submission API.")
rval.append(src_id_to_item(sa_session=trans.sa_session, value=value, security=trans.security))
elif str(value).startswith("__collection_reduce__|"):
encoded_ids = [v[len("__collection_reduce__|") :] for v in str(value).split(",")]
decoded_ids = map(trans.security.decode_id, encoded_ids)
Expand Down

0 comments on commit af386d2

Please sign in to comment.