From dca7150ab172c23e7d4eeb038706689e134e9aea Mon Sep 17 00:00:00 2001 From: John Chilton Date: Mon, 18 May 2020 14:26:38 -0400 Subject: [PATCH] [WIP] Implement records - heterogeneous dataset collections. Existing dataset collection types are meant to be homogeneous - all datasets of the same type. This introduces CWL-style record dataset collections. --- lib/galaxy/jobs/__init__.py | 5 +- lib/galaxy/managers/collections.py | 12 ++- lib/galaxy/managers/collections_util.py | 1 + lib/galaxy/model/__init__.py | 15 ++- .../model/dataset_collections/builder.py | 22 +++- .../model/dataset_collections/registry.py | 11 +- .../dataset_collections/type_description.py | 9 +- .../dataset_collections/types/__init__.py | 2 +- .../model/dataset_collections/types/list.py | 4 +- .../model/dataset_collections/types/paired.py | 8 +- .../model/dataset_collections/types/record.py | 45 ++++++++ lib/galaxy/tool_util/parser/output_objects.py | 8 +- lib/galaxy/tools/actions/__init__.py | 5 +- lib/galaxy/workflow/modules.py | 2 +- .../api/test_dataset_collections.py | 101 ++++++++++++++++++ 15 files changed, 223 insertions(+), 27 deletions(-) create mode 100644 lib/galaxy/model/dataset_collections/types/record.py diff --git a/lib/galaxy/jobs/__init__.py b/lib/galaxy/jobs/__init__.py index c26aa11dbc0a..70b648ffa57d 100644 --- a/lib/galaxy/jobs/__init__.py +++ b/lib/galaxy/jobs/__init__.py @@ -1776,8 +1776,9 @@ def _finish_dataset( dataset.mark_unhidden() elif not purged: # If the tool was expected to set the extension, attempt to retrieve it - if dataset.ext == "auto": - dataset.extension = context.get("ext", "data") + context_ext = context.get("ext", "data") + if dataset.ext == "auto" or (dataset.ext == "data" and context_ext != "data"): + dataset.extension = context_ext dataset.init_meta(copy_from=dataset) # if a dataset was copied, it won't appear in our dictionary: # either use the metadata from originating output dataset, or call set_meta on the copies diff --git 
a/lib/galaxy/managers/collections.py b/lib/galaxy/managers/collections.py index fae6d0563347..e9042e3aac55 100644 --- a/lib/galaxy/managers/collections.py +++ b/lib/galaxy/managers/collections.py @@ -175,6 +175,7 @@ def create( flush=True, completed_job=None, output_name=None, + fields=None, ): """ PRECONDITION: security checks on ability to add to parent @@ -199,6 +200,7 @@ def create( hide_source_items=hide_source_items, copy_elements=copy_elements, history=history, + fields=fields, ) implicit_inputs = [] @@ -285,17 +287,20 @@ def create_dataset_collection( hide_source_items=None, copy_elements=False, history=None, + fields=None, ): # Make sure at least one of these is None. assert element_identifiers is None or elements is None - if element_identifiers is None and elements is None: raise RequestParameterInvalidException(ERROR_INVALID_ELEMENTS_SPECIFICATION) if not collection_type: raise RequestParameterInvalidException(ERROR_NO_COLLECTION_TYPE) - collection_type_description = self.collection_type_descriptions.for_collection_type(collection_type) + collection_type_description = self.collection_type_descriptions.for_collection_type( + collection_type, fields=fields + ) has_subcollections = collection_type_description.has_subcollections() + # If we have elements, this is an internal request, don't need to load # objects from identifiers. if elements is None: @@ -319,8 +324,9 @@ def create_dataset_collection( if elements is not self.ELEMENTS_UNINITIALIZED: type_plugin = collection_type_description.rank_type_plugin() - dataset_collection = builder.build_collection(type_plugin, elements) + dataset_collection = builder.build_collection(type_plugin, elements, fields=fields) else: + # TODO: Pass fields here - need test case first. 
dataset_collection = model.DatasetCollection(populated=False) dataset_collection.collection_type = collection_type return dataset_collection diff --git a/lib/galaxy/managers/collections_util.py b/lib/galaxy/managers/collections_util.py index 32feab23889d..7f129992c754 100644 --- a/lib/galaxy/managers/collections_util.py +++ b/lib/galaxy/managers/collections_util.py @@ -39,6 +39,7 @@ def api_payload_to_create_params(payload): name=payload.get("name", None), hide_source_items=string_as_bool(payload.get("hide_source_items", False)), copy_elements=string_as_bool(payload.get("copy_elements", False)), + fields=payload.get("fields", None), ) return params diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py index 8921f6ceaf6c..fb65c01516fb 100644 --- a/lib/galaxy/model/__init__.py +++ b/lib/galaxy/model/__init__.py @@ -6545,12 +6545,21 @@ class DatasetCollection(Base, Dictifiable, UsesAnnotations, Serializable): populated_states = DatasetCollectionPopulatedState - def __init__(self, id=None, collection_type=None, populated=True, element_count=None): + def __init__( + self, + id=None, + collection_type=None, + populated=True, + element_count=None, + fields=None, + ): self.id = id self.collection_type = collection_type if not populated: self.populated_state = DatasetCollection.populated_states.NEW self.element_count = element_count + # TODO: persist fields... 
+ self.fields = fields def _build_nested_collection_attributes_stmt( self, @@ -6725,6 +6734,10 @@ def populated_optimized(self): return self._populated_optimized + @property + def allow_implicit_mapping(self): + return self.collection_type != "record" + @property def populated(self): top_level_populated = self.populated_state == DatasetCollection.populated_states.OK diff --git a/lib/galaxy/model/dataset_collections/builder.py b/lib/galaxy/model/dataset_collections/builder.py index 2ae001f33a22..73af774904fe 100644 --- a/lib/galaxy/model/dataset_collections/builder.py +++ b/lib/galaxy/model/dataset_collections/builder.py @@ -4,25 +4,27 @@ from .type_description import COLLECTION_TYPE_DESCRIPTION_FACTORY -def build_collection(type, dataset_instances, collection=None, associated_identifiers=None): +def build_collection(type, dataset_instances, collection=None, associated_identifiers=None, fields=None): """ Build DatasetCollection with populated DatasetcollectionElement objects corresponding to the supplied dataset instances or throw exception if this is not a valid collection of the specified type. 
""" - dataset_collection = collection or model.DatasetCollection() + dataset_collection = collection or model.DatasetCollection(fields=fields) associated_identifiers = associated_identifiers or set() - set_collection_elements(dataset_collection, type, dataset_instances, associated_identifiers) + set_collection_elements(dataset_collection, type, dataset_instances, associated_identifiers, fields=fields) return dataset_collection -def set_collection_elements(dataset_collection, type, dataset_instances, associated_identifiers): +def set_collection_elements(dataset_collection, type, dataset_instances, associated_identifiers, fields=None): new_element_keys = OrderedSet(dataset_instances.keys()) - associated_identifiers new_dataset_instances = {k: dataset_instances[k] for k in new_element_keys} dataset_collection.element_count = dataset_collection.element_count or 0 element_index = dataset_collection.element_count elements = [] - for element in type.generate_elements(new_dataset_instances): + if fields == "auto": + fields = guess_fields(dataset_instances) + for element in type.generate_elements(new_dataset_instances, fields=fields): element.element_index = element_index add_object_to_object_session(element, dataset_collection) element.collection = dataset_collection @@ -35,6 +37,16 @@ def set_collection_elements(dataset_collection, type, dataset_instances, associa return dataset_collection +def guess_fields(dataset_instances): + fields = [] + for identifier, element in dataset_instances.items(): + # TODO: Make generic enough to handle nested record types. 
+ assert element.history_content_type == "dataset" + fields.append({"type": "File", "name": identifier}) + + return fields + + class CollectionBuilder: """Purely functional builder pattern for building a dataset collection.""" diff --git a/lib/galaxy/model/dataset_collections/registry.py b/lib/galaxy/model/dataset_collections/registry.py index 9c849dfdad6f..bd148edafd2d 100644 --- a/lib/galaxy/model/dataset_collections/registry.py +++ b/lib/galaxy/model/dataset_collections/registry.py @@ -2,9 +2,14 @@ from .types import ( list, paired, + record, ) -PLUGIN_CLASSES = [list.ListDatasetCollectionType, paired.PairedDatasetCollectionType] +PLUGIN_CLASSES = [ + list.ListDatasetCollectionType, + paired.PairedDatasetCollectionType, + record.RecordDatasetCollectionType, +] class DatasetCollectionTypesRegistry: @@ -14,13 +19,13 @@ def __init__(self): def get(self, plugin_type): return self.__plugins[plugin_type] - def prototype(self, plugin_type): + def prototype(self, plugin_type, fields=None): plugin_type_object = self.get(plugin_type) if not hasattr(plugin_type_object, "prototype_elements"): raise Exception(f"Cannot pre-determine structure for collection of type {plugin_type}") dataset_collection = model.DatasetCollection() - for e in plugin_type_object.prototype_elements(): + for e in plugin_type_object.prototype_elements(fields=fields): e.collection = dataset_collection return dataset_collection diff --git a/lib/galaxy/model/dataset_collections/type_description.py b/lib/galaxy/model/dataset_collections/type_description.py index cade102453ca..172ac1976b27 100644 --- a/lib/galaxy/model/dataset_collections/type_description.py +++ b/lib/galaxy/model/dataset_collections/type_description.py @@ -9,9 +9,9 @@ def __init__(self, type_registry=DATASET_COLLECTION_TYPES_REGISTRY): # I think. 
self.type_registry = type_registry - def for_collection_type(self, collection_type): + def for_collection_type(self, collection_type, fields=None): assert collection_type is not None - return CollectionTypeDescription(collection_type, self) + return CollectionTypeDescription(collection_type, self, fields=fields) class CollectionTypeDescription: @@ -47,12 +47,15 @@ class CollectionTypeDescription: collection_type: str - def __init__(self, collection_type: Union[str, "CollectionTypeDescription"], collection_type_description_factory): + def __init__( + self, collection_type: Union[str, "CollectionTypeDescription"], collection_type_description_factory, fields=None + ): if isinstance(collection_type, CollectionTypeDescription): self.collection_type = collection_type.collection_type else: self.collection_type = collection_type self.collection_type_description_factory = collection_type_description_factory + self.fields = fields self.__has_subcollections = self.collection_type.find(":") > 0 def child_collection_type(self): diff --git a/lib/galaxy/model/dataset_collections/types/__init__.py b/lib/galaxy/model/dataset_collections/types/__init__.py index bfcf7bae79a6..c294f6957be6 100644 --- a/lib/galaxy/model/dataset_collections/types/__init__.py +++ b/lib/galaxy/model/dataset_collections/types/__init__.py @@ -11,7 +11,7 @@ class DatasetCollectionType(metaclass=ABCMeta): @abstractmethod - def generate_elements(self, dataset_instances): + def generate_elements(self, dataset_instances: dict, **kwds): """Generate DatasetCollectionElements with corresponding to the supplied dataset instances or throw exception if this is not a valid collection of the specified type. 
diff --git a/lib/galaxy/model/dataset_collections/types/list.py b/lib/galaxy/model/dataset_collections/types/list.py index 18ce4db76537..d4421d009c34 100644 --- a/lib/galaxy/model/dataset_collections/types/list.py +++ b/lib/galaxy/model/dataset_collections/types/list.py @@ -7,8 +7,8 @@ class ListDatasetCollectionType(BaseDatasetCollectionType): collection_type = "list" - def generate_elements(self, elements): - for identifier, element in elements.items(): + def generate_elements(self, dataset_instances, **kwds): + for identifier, element in dataset_instances.items(): association = DatasetCollectionElement( element=element, element_identifier=identifier, diff --git a/lib/galaxy/model/dataset_collections/types/paired.py b/lib/galaxy/model/dataset_collections/types/paired.py index 4ae95a1442a2..e774ab67aace 100644 --- a/lib/galaxy/model/dataset_collections/types/paired.py +++ b/lib/galaxy/model/dataset_collections/types/paired.py @@ -15,21 +15,21 @@ class PairedDatasetCollectionType(BaseDatasetCollectionType): collection_type = "paired" - def generate_elements(self, elements): - if forward_dataset := elements.get(FORWARD_IDENTIFIER): + def generate_elements(self, dataset_instances, **kwds): + if forward_dataset := dataset_instances.get(FORWARD_IDENTIFIER): left_association = DatasetCollectionElement( element=forward_dataset, element_identifier=FORWARD_IDENTIFIER, ) yield left_association - if reverse_dataset := elements.get(REVERSE_IDENTIFIER): + if reverse_dataset := dataset_instances.get(REVERSE_IDENTIFIER): right_association = DatasetCollectionElement( element=reverse_dataset, element_identifier=REVERSE_IDENTIFIER, ) yield right_association - def prototype_elements(self): + def prototype_elements(self, **kwds): left_association = DatasetCollectionElement( element=HistoryDatasetAssociation(), element_identifier=FORWARD_IDENTIFIER, diff --git a/lib/galaxy/model/dataset_collections/types/record.py b/lib/galaxy/model/dataset_collections/types/record.py new file mode 
100644 index 000000000000..193509f439ee --- /dev/null +++ b/lib/galaxy/model/dataset_collections/types/record.py @@ -0,0 +1,45 @@ +from galaxy.exceptions import RequestParameterMissingException +from galaxy.model import ( + DatasetCollectionElement, + HistoryDatasetAssociation, +) +from ..types import BaseDatasetCollectionType + + +class RecordDatasetCollectionType(BaseDatasetCollectionType): + """Arbitrary CWL-style record type.""" + + collection_type = "record" + + def generate_elements(self, dataset_instances, **kwds): + fields = kwds.get("fields", None) + if fields is None: + raise RequestParameterMissingException("Missing or null parameter 'fields' required for record types.") + if len(dataset_instances) != len(fields): + self._validation_failed("Supplied element do not match fields.") + index = 0 + for identifier, element in dataset_instances.items(): + field = fields[index] + if field["name"] != identifier: + self._validation_failed("Supplied element do not match fields.") + + # TODO: validate type and such. 
+ association = DatasetCollectionElement( + element=element, + element_identifier=identifier, + ) + yield association + index += 1 + + def prototype_elements(self, fields=None, **kwds): + if fields is None: + raise RequestParameterMissingException("Missing or null parameter 'fields' required for record types.") + for field in fields: + name = field.get("name", None) + assert name + assert field.get("type", "File") # NS: this assert doesn't make sense as it is + field_dataset = DatasetCollectionElement( + element=HistoryDatasetAssociation(), + element_identifier=name, + ) + yield field_dataset diff --git a/lib/galaxy/tool_util/parser/output_objects.py b/lib/galaxy/tool_util/parser/output_objects.py index 63148c1fb946..7825d6308197 100644 --- a/lib/galaxy/tool_util/parser/output_objects.py +++ b/lib/galaxy/tool_util/parser/output_objects.py @@ -402,12 +402,14 @@ def __init__( collection_type_from_rules: Optional[str] = None, structured_like: Optional[str] = None, dataset_collector_descriptions: Optional[List[DatasetCollectionDescription]] = None, + fields=None, ) -> None: self.collection_type = collection_type self.collection_type_source = collection_type_source self.collection_type_from_rules = collection_type_from_rules self.structured_like = structured_like self.dataset_collector_descriptions = dataset_collector_descriptions or [] + self.fields = fields if collection_type and collection_type_source: raise ValueError("Cannot set both type and type_source on collection output.") if ( @@ -424,6 +426,10 @@ def __init__( raise ValueError( "Cannot specify dynamic structure (discover_datasets) and collection type attributes structured_like or collection_type_from_rules." 
) + if collection_type == "record" and fields is None: + raise ValueError("If record outputs are defined, fields must be defined as well.") + if fields is not None and collection_type != "record": + raise ValueError("If fields are specified for outputs, the collection type must be record.") self.dynamic = bool(dataset_collector_descriptions) def collection_prototype(self, inputs, type_registry): @@ -433,7 +439,7 @@ def collection_prototype(self, inputs, type_registry): else: collection_type = self.collection_type assert collection_type - collection_prototype = type_registry.prototype(collection_type) + collection_prototype = type_registry.prototype(collection_type, fields=self.fields) collection_prototype.collection_type = collection_type return collection_prototype diff --git a/lib/galaxy/tools/actions/__init__.py b/lib/galaxy/tools/actions/__init__.py index 841eea988d49..f7a2138795a3 100644 --- a/lib/galaxy/tools/actions/__init__.py +++ b/lib/galaxy/tools/actions/__init__.py @@ -679,7 +679,10 @@ def handle_output(name, output, hidden=None): assert not element_identifiers # known_outputs must have been empty element_kwds = dict(elements=collections_manager.ELEMENTS_UNINITIALIZED) else: - element_kwds = dict(element_identifiers=element_identifiers) + element_kwds = dict( + element_identifiers=element_identifiers, + fields=output.structure.fields, + ) output_collections.create_collection( output=output, name=name, completed_job=completed_job, **element_kwds ) diff --git a/lib/galaxy/workflow/modules.py b/lib/galaxy/workflow/modules.py index 3d0f267ff3bb..ae5f8e31e920 100644 --- a/lib/galaxy/workflow/modules.py +++ b/lib/galaxy/workflow/modules.py @@ -551,7 +551,7 @@ def _find_collections_to_match(self, progress: "WorkflowProgress", step, all_inp for input_dict in all_inputs: name = input_dict["name"] data = progress.replacement_for_input(self.trans, step, input_dict) - can_map_over = hasattr(data, "collection") # and data.collection.allow_implicit_mapping + 
can_map_over = hasattr(data, "collection") and data.collection.allow_implicit_mapping if not can_map_over: continue diff --git a/lib/galaxy_test/api/test_dataset_collections.py b/lib/galaxy_test/api/test_dataset_collections.py index d7710c57b2fa..a4ba01877e5e 100644 --- a/lib/galaxy_test/api/test_dataset_collections.py +++ b/lib/galaxy_test/api/test_dataset_collections.py @@ -1,3 +1,4 @@ +import json import zipfile from io import BytesIO from typing import List @@ -100,6 +101,106 @@ def test_create_list_of_new_pairs(self): pair_1_element_1 = pair_elements[0] assert pair_1_element_1["element_index"] == 0 + def test_create_record(self, history_id): + contents = [ + ("condition", "1\t2\t3"), + ("control1", "4\t5\t6"), + ("control2", "7\t8\t9"), + ] + record_identifiers = self.dataset_collection_populator.list_identifiers(history_id, contents) + fields = [ + {"name": "condition", "type": "File"}, + {"name": "control1", "type": "File"}, + {"name": "control2", "type": "File"}, + ] + payload = dict( + name="a record", + instance_type="history", + history_id=history_id, + element_identifiers=json.dumps(record_identifiers), + collection_type="record", + fields=json.dumps(fields), + ) + create_response = self._post("dataset_collections", payload) + dataset_collection = self._check_create_response(create_response) + assert dataset_collection["collection_type"] == "record" + assert dataset_collection["name"] == "a record" + returned_collections = dataset_collection["elements"] + assert len(returned_collections) == 3, dataset_collection + record_pos_0_element = returned_collections[0] + self._assert_has_keys(record_pos_0_element, "element_index") + record_pos_0_object = record_pos_0_element["object"] + self._assert_has_keys(record_pos_0_object, "name", "history_content_type") + + def test_record_requires_fields(self, history_id): + contents = [ + ("condition", "1\t2\t3"), + ("control1", "4\t5\t6"), + ("control2", "7\t8\t9"), + ] + record_identifiers = 
self.dataset_collection_populator.list_identifiers(history_id, contents) + payload = dict( + name="a record", + instance_type="history", + history_id=history_id, + element_identifiers=json.dumps(record_identifiers), + collection_type="record", + ) + create_response = self._post("dataset_collections", payload) + self._assert_status_code_is(create_response, 400) + + def test_record_auto_fields(self, history_id): + contents = [ + ("condition", "1\t2\t3"), + ("control1", "4\t5\t6"), + ("control2", "7\t8\t9"), + ] + record_identifiers = self.dataset_collection_populator.list_identifiers(history_id, contents) + payload = dict( + name="a record", + instance_type="history", + history_id=history_id, + element_identifiers=json.dumps(record_identifiers), + collection_type="record", + fields="auto", + ) + create_response = self._post("dataset_collections", payload) + self._check_create_response(create_response) + + def test_record_field_validation(self, history_id): + contents = [ + ("condition", "1\t2\t3"), + ("control1", "4\t5\t6"), + ("control2", "7\t8\t9"), + ] + record_identifiers = self.dataset_collection_populator.list_identifiers(history_id, contents) + too_few_fields = [ + {"name": "condition", "type": "File"}, + {"name": "control1", "type": "File"}, + ] + too_many_fields = [ + {"name": "condition", "type": "File"}, + {"name": "control1", "type": "File"}, + {"name": "control2", "type": "File"}, + {"name": "control3", "type": "File"}, + ] + wrong_name_fields = [ + {"name": "condition", "type": "File"}, + {"name": "control1", "type": "File"}, + {"name": "control3", "type": "File"}, + ] + for fields in [too_few_fields, too_many_fields, wrong_name_fields]: + payload = dict( + name="a record", + instance_type="history", + history_id=history_id, + element_identifiers=json.dumps(record_identifiers), + collection_type="record", + fields=json.dumps(fields), + ) + create_response = self._post("dataset_collections", payload) + self._assert_status_code_is(create_response, 400) + 
def test_list_download(self): with self.dataset_populator.test_history(require_new=False) as history_id: fetch_response = self.dataset_collection_populator.create_list_in_history(