From 296bf342dac709cca36a3ca91216f0651c7d3762 Mon Sep 17 00:00:00 2001 From: John Chilton Date: Wed, 18 Dec 2024 21:59:31 -0500 Subject: [PATCH] Implement paired_or_unpaired collections... --- .../model/dataset_collections/registry.py | 2 + .../dataset_collections/types/__init__.py | 6 + .../model/dataset_collections/types/paired.py | 9 +- .../types/paired_or_unpaired.py | 45 ++++++ lib/galaxy/schema/schema.py | 2 + lib/galaxy/tools/__init__.py | 56 ++++++++ .../tools/split_paired_and_unpaired.xml | 132 ++++++++++++++++++ lib/galaxy/tools/wrappers.py | 4 + .../api/test_dataset_collections.py | 19 +++ .../tools/collection_paired_or_unpaired.xml | 51 +++++++ test/functional/tools/sample_tool_conf.xml | 3 +- 11 files changed, 326 insertions(+), 3 deletions(-) create mode 100644 lib/galaxy/model/dataset_collections/types/paired_or_unpaired.py create mode 100644 lib/galaxy/tools/split_paired_and_unpaired.xml create mode 100644 test/functional/tools/collection_paired_or_unpaired.xml diff --git a/lib/galaxy/model/dataset_collections/registry.py b/lib/galaxy/model/dataset_collections/registry.py index bd148edafd2d..ed75294f68e7 100644 --- a/lib/galaxy/model/dataset_collections/registry.py +++ b/lib/galaxy/model/dataset_collections/registry.py @@ -2,6 +2,7 @@ from .types import ( list, paired, + paired_or_unpaired, record, ) @@ -9,6 +10,7 @@ list.ListDatasetCollectionType, paired.PairedDatasetCollectionType, record.RecordDatasetCollectionType, + paired_or_unpaired.PairedOrUnpairedDatasetCollectionType, ] diff --git a/lib/galaxy/model/dataset_collections/types/__init__.py b/lib/galaxy/model/dataset_collections/types/__init__.py index c294f6957be6..831c07c9ca18 100644 --- a/lib/galaxy/model/dataset_collections/types/__init__.py +++ b/lib/galaxy/model/dataset_collections/types/__init__.py @@ -21,3 +21,9 @@ def generate_elements(self, dataset_instances: dict, **kwds): class BaseDatasetCollectionType(DatasetCollectionType): def _validation_failed(self, message): raise 
exceptions.ObjectAttributeInvalidException(message) + + def _ensure_dataset_with_identifier(self, dataset_instances: dict, name: str): + dataset_instance = dataset_instances.get(name) + if dataset_instance is None: + raise exceptions.ObjectAttributeInvalidException(f"An element with the identifier {name} is required to create this collection type") + return dataset_instance diff --git a/lib/galaxy/model/dataset_collections/types/paired.py b/lib/galaxy/model/dataset_collections/types/paired.py index e774ab67aace..e7677cee482b 100644 --- a/lib/galaxy/model/dataset_collections/types/paired.py +++ b/lib/galaxy/model/dataset_collections/types/paired.py @@ -1,3 +1,4 @@ +from galaxy.exceptions import RequestParameterInvalidException from galaxy.model import ( DatasetCollectionElement, HistoryDatasetAssociation, @@ -16,13 +17,17 @@ class PairedDatasetCollectionType(BaseDatasetCollectionType): collection_type = "paired" def generate_elements(self, dataset_instances, **kwds): - if forward_dataset := dataset_instances.get(FORWARD_IDENTIFIER): + num_datasets = len(dataset_instances) + if num_datasets != 2: + raise RequestParameterInvalidException(f"Incorrect number of datasets - 2 datasets exactly are required to create a single_or_paired collection") + + if forward_dataset := self._ensure_dataset_with_identifier(dataset_instances, FORWARD_IDENTIFIER): left_association = DatasetCollectionElement( element=forward_dataset, element_identifier=FORWARD_IDENTIFIER, ) yield left_association - if reverse_dataset := dataset_instances.get(REVERSE_IDENTIFIER): + if reverse_dataset := self._ensure_dataset_with_identifier(dataset_instances, REVERSE_IDENTIFIER): right_association = DatasetCollectionElement( element=reverse_dataset, element_identifier=REVERSE_IDENTIFIER, diff --git a/lib/galaxy/model/dataset_collections/types/paired_or_unpaired.py b/lib/galaxy/model/dataset_collections/types/paired_or_unpaired.py new file mode 100644 index 000000000000..bc736c72f2b8 --- /dev/null +++ 
class PairedOrUnpairedDatasetCollectionType(BaseDatasetCollectionType):
    """Collection type holding either a forward/reverse pair or one unpaired dataset.

    With two datasets the elements are identified by ``FORWARD_IDENTIFIER``
    and ``REVERSE_IDENTIFIER``; with a single dataset the element is
    identified by ``SINGLETON_IDENTIFIER``.
    """

    collection_type = "paired_or_unpaired"

    def generate_elements(self, dataset_instances, **kwds):
        """Yield one ``DatasetCollectionElement`` per required identifier.

        :param dataset_instances: mapping of element identifier to dataset;
            must contain exactly one or exactly two entries.
        :param kwds: accepted for interface compatibility and not used here.
        :raises RequestParameterInvalidException: when the mapping does not
            contain one or two entries.
        :raises exceptions.ObjectAttributeInvalidException: (from
            ``_ensure_dataset_with_identifier``) when a required identifier
            is missing or mapped to ``None``.

        This is a generator, so the checks run when iteration starts rather
        than when the method is called.
        """
        num_datasets = len(dataset_instances)
        if num_datasets > 2 or num_datasets < 1:
            raise RequestParameterInvalidException(
                "Incorrect number of datasets - 1 or 2 datasets is required to create a paired_or_unpaired collection"
            )

        if num_datasets == 2:
            identifiers = (FORWARD_IDENTIFIER, REVERSE_IDENTIFIER)
        else:
            identifiers = (SINGLETON_IDENTIFIER,)

        # Resolve every required identifier before yielding anything, so a
        # missing "reverse" is reported before a "forward" element has been
        # handed to the caller.
        datasets = [
            self._ensure_dataset_with_identifier(dataset_instances, identifier) for identifier in identifiers
        ]

        # The helper never returns None, so each resolved dataset is yielded
        # unconditionally instead of being filtered on truthiness.
        for identifier, dataset in zip(identifiers, datasets):
            yield DatasetCollectionElement(
                element=dataset,
                element_identifier=identifier,
            )
class SplitPairedAndUnpairedTool(DatabaseOperationTool):
    """Split a list-like collection into an unpaired and a paired output.

    Accepts ``list``, ``list:paired`` and ``list:paired_or_unpaired`` inputs
    and creates the ``output_unpaired`` and ``output_paired`` collections
    from copies of the input elements.
    """

    tool_type = "split_paired_and_unpaired"
    require_terminal_states = False
    require_dataset_ok = False

    def produce_outputs(self, trans, out_data, output_collections, incoming, history, **kwds):
        """Copy each input element into the unpaired or the paired output.

        Reads the collection from ``incoming["input"]``, copies every
        element, adds the copied datasets to ``history`` and creates both
        output collections through ``output_collections``.
        """
        source = incoming["input"]
        # Objects exposing ``element_type`` are treated as a DCE, anything
        # else as an HDCA.
        collection = source.element_object if hasattr(source, "element_type") else source.collection

        collection_type = collection.collection_type
        assert collection_type in ["list", "list:paired", "list:paired_or_unpaired"]

        unpaired_by_identifier = {}
        paired_by_identifier = {}
        datasets_from_pairs = []

        def _is_dataset(candidate):
            return getattr(candidate, "history_content_type", None) == "dataset"

        for dce in collection.elements:
            identifier = dce.element_identifier
            child = dce.element_object

            # Flat lists are all unpaired, list:paired is all paired, and the
            # mixed type is decided per element.
            if collection_type == "list":
                treat_as_unpaired = True
            elif collection_type == "list:paired":
                treat_as_unpaired = False
            else:
                treat_as_unpaired = _is_dataset(child)

            if treat_as_unpaired:
                assert _is_dataset(child)
                unpaired_by_identifier[identifier] = child.copy(copy_tags=child.tags, flush=False)
            else:
                pair_copy = child.copy(flush=False)
                paired_by_identifier[identifier] = pair_copy
                for position in (0, 1):
                    datasets_from_pairs.append(pair_copy.elements[position].element_object)

        for copied_datasets in (unpaired_by_identifier.values(), datasets_from_pairs):
            self._add_datasets_to_history(history, copied_datasets)

        for output_name, elements in (
            ("output_unpaired", unpaired_by_identifier),
            ("output_paired", paired_by_identifier),
        ):
            output_collections.create_collection(
                self.outputs[output_name], output_name, elements=elements, propagate_hda_tags=False
            )
contents = [ + ("unpaired", "1\t2\t3"), + ] + single_identifier = self.dataset_collection_populator.list_identifiers(history_id, contents) + payload = dict( + name=collection_name, + instance_type="history", + history_id=history_id, + element_identifiers=single_identifier, + collection_type="paired_or_unpaired", + ) + create_response = self._post("dataset_collections", payload, json=True) + dataset_collection = self._check_create_response(create_response) + assert dataset_collection["collection_type"] == "paired_or_unpaired" + returned_collections = dataset_collection["elements"] + assert len(returned_collections) == 1, dataset_collection + def test_create_record(self, history_id): contents = [ ("condition", "1\t2\t3"), diff --git a/test/functional/tools/collection_paired_or_unpaired.xml b/test/functional/tools/collection_paired_or_unpaired.xml new file mode 100644 index 000000000000..d250e1f859fa --- /dev/null +++ b/test/functional/tools/collection_paired_or_unpaired.xml @@ -0,0 +1,51 @@ + + + #if $f1.has_single_item: + cat $f1.unpaired >> $out1; + echo "Single item" + #else + cat $f1.forward $f1['reverse'] >> $out1; + echo "Paired items" + #end if + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/functional/tools/sample_tool_conf.xml b/test/functional/tools/sample_tool_conf.xml index 477fe69bde34..49e414eeabd6 100644 --- a/test/functional/tools/sample_tool_conf.xml +++ b/test/functional/tools/sample_tool_conf.xml @@ -213,6 +213,7 @@ + @@ -320,5 +321,5 @@ - +