Skip to content

Commit

Permalink
[WIP] Implement records - heterogenous dataset collections.
Browse files Browse the repository at this point in the history
Existing dataset colleciton types are meant to be homogenous - all datasets of the same time. This introduces CWL-style record dataset collections.
  • Loading branch information
jmchilton committed Dec 9, 2024
1 parent 39e38c9 commit 5c456d1
Show file tree
Hide file tree
Showing 15 changed files with 223 additions and 27 deletions.
5 changes: 3 additions & 2 deletions lib/galaxy/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,8 +1776,9 @@ def _finish_dataset(
dataset.mark_unhidden()
elif not purged:
# If the tool was expected to set the extension, attempt to retrieve it
if dataset.ext == "auto":
dataset.extension = context.get("ext", "data")
context_ext = context.get("ext", "data")
if dataset.ext == "auto" or (dataset.ext == "data" and context_ext != "data"):
dataset.extension = context_ext
dataset.init_meta(copy_from=dataset)
# if a dataset was copied, it won't appear in our dictionary:
# either use the metadata from originating output dataset, or call set_meta on the copies
Expand Down
12 changes: 9 additions & 3 deletions lib/galaxy/managers/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def create(
flush=True,
completed_job=None,
output_name=None,
fields=None,
):
"""
PRECONDITION: security checks on ability to add to parent
Expand All @@ -199,6 +200,7 @@ def create(
hide_source_items=hide_source_items,
copy_elements=copy_elements,
history=history,
fields=fields,
)

implicit_inputs = []
Expand Down Expand Up @@ -285,17 +287,20 @@ def create_dataset_collection(
hide_source_items=None,
copy_elements=False,
history=None,
fields=None,
):
# Make sure at least one of these is None.
assert element_identifiers is None or elements is None

if element_identifiers is None and elements is None:
raise RequestParameterInvalidException(ERROR_INVALID_ELEMENTS_SPECIFICATION)
if not collection_type:
raise RequestParameterInvalidException(ERROR_NO_COLLECTION_TYPE)

collection_type_description = self.collection_type_descriptions.for_collection_type(collection_type)
collection_type_description = self.collection_type_descriptions.for_collection_type(
collection_type, fields=fields
)
has_subcollections = collection_type_description.has_subcollections()

# If we have elements, this is an internal request, don't need to load
# objects from identifiers.
if elements is None:
Expand All @@ -319,8 +324,9 @@ def create_dataset_collection(

if elements is not self.ELEMENTS_UNINITIALIZED:
type_plugin = collection_type_description.rank_type_plugin()
dataset_collection = builder.build_collection(type_plugin, elements)
dataset_collection = builder.build_collection(type_plugin, elements, fields=fields)
else:
# TODO: Pass fields here - need test case first.
dataset_collection = model.DatasetCollection(populated=False)
dataset_collection.collection_type = collection_type
return dataset_collection
Expand Down
1 change: 1 addition & 0 deletions lib/galaxy/managers/collections_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def api_payload_to_create_params(payload):
name=payload.get("name", None),
hide_source_items=string_as_bool(payload.get("hide_source_items", False)),
copy_elements=string_as_bool(payload.get("copy_elements", False)),
fields=payload.get("fields", None),
)
return params

Expand Down
15 changes: 14 additions & 1 deletion lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6533,12 +6533,21 @@ class DatasetCollection(Base, Dictifiable, UsesAnnotations, Serializable):

populated_states = DatasetCollectionPopulatedState

def __init__(self, id=None, collection_type=None, populated=True, element_count=None):
def __init__(
self,
id=None,
collection_type=None,
populated=True,
element_count=None,
fields=None,
):
self.id = id
self.collection_type = collection_type
if not populated:
self.populated_state = DatasetCollection.populated_states.NEW
self.element_count = element_count
# TODO: persist fields...
self.fields = fields

def _build_nested_collection_attributes_stmt(
self,
Expand Down Expand Up @@ -6713,6 +6722,10 @@ def populated_optimized(self):

return self._populated_optimized

@property
def allow_implicit_mapping(self):
return self.collection_type != "record"

@property
def populated(self):
top_level_populated = self.populated_state == DatasetCollection.populated_states.OK
Expand Down
22 changes: 17 additions & 5 deletions lib/galaxy/model/dataset_collections/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,27 @@
from .type_description import COLLECTION_TYPE_DESCRIPTION_FACTORY


def build_collection(type, dataset_instances, collection=None, associated_identifiers=None):
def build_collection(type, dataset_instances, collection=None, associated_identifiers=None, fields=None):
"""
Build DatasetCollection with populated DatasetcollectionElement objects
corresponding to the supplied dataset instances or throw exception if
this is not a valid collection of the specified type.
"""
dataset_collection = collection or model.DatasetCollection()
dataset_collection = collection or model.DatasetCollection(fields=fields)
associated_identifiers = associated_identifiers or set()
set_collection_elements(dataset_collection, type, dataset_instances, associated_identifiers)
set_collection_elements(dataset_collection, type, dataset_instances, associated_identifiers, fields=fields)
return dataset_collection


def set_collection_elements(dataset_collection, type, dataset_instances, associated_identifiers):
def set_collection_elements(dataset_collection, type, dataset_instances, associated_identifiers, fields=None):
new_element_keys = OrderedSet(dataset_instances.keys()) - associated_identifiers
new_dataset_instances = {k: dataset_instances[k] for k in new_element_keys}
dataset_collection.element_count = dataset_collection.element_count or 0
element_index = dataset_collection.element_count
elements = []
for element in type.generate_elements(new_dataset_instances):
if fields == "auto":
fields = guess_fields(dataset_instances)
for element in type.generate_elements(new_dataset_instances, fields=fields):
element.element_index = element_index
add_object_to_object_session(element, dataset_collection)
element.collection = dataset_collection
Expand All @@ -35,6 +37,16 @@ def set_collection_elements(dataset_collection, type, dataset_instances, associa
return dataset_collection


def guess_fields(dataset_instances):
fields = []
for identifier, element in dataset_instances.items():
# TODO: Make generic enough to handle nested record types.
assert element.history_content_type == "dataset"
fields.append({"type": "File", "name": identifier})

return fields


class CollectionBuilder:
"""Purely functional builder pattern for building a dataset collection."""

Expand Down
11 changes: 8 additions & 3 deletions lib/galaxy/model/dataset_collections/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
from .types import (
list,
paired,
record,
)

PLUGIN_CLASSES = [list.ListDatasetCollectionType, paired.PairedDatasetCollectionType]
PLUGIN_CLASSES = [
list.ListDatasetCollectionType,
paired.PairedDatasetCollectionType,
record.RecordDatasetCollectionType,
]


class DatasetCollectionTypesRegistry:
Expand All @@ -14,13 +19,13 @@ def __init__(self):
def get(self, plugin_type):
return self.__plugins[plugin_type]

def prototype(self, plugin_type):
def prototype(self, plugin_type, fields=None):
plugin_type_object = self.get(plugin_type)
if not hasattr(plugin_type_object, "prototype_elements"):
raise Exception(f"Cannot pre-determine structure for collection of type {plugin_type}")

dataset_collection = model.DatasetCollection()
for e in plugin_type_object.prototype_elements():
for e in plugin_type_object.prototype_elements(fields=fields):
e.collection = dataset_collection
return dataset_collection

Expand Down
9 changes: 6 additions & 3 deletions lib/galaxy/model/dataset_collections/type_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ def __init__(self, type_registry=DATASET_COLLECTION_TYPES_REGISTRY):
# I think.
self.type_registry = type_registry

def for_collection_type(self, collection_type):
def for_collection_type(self, collection_type, fields=None):
assert collection_type is not None
return CollectionTypeDescription(collection_type, self)
return CollectionTypeDescription(collection_type, self, fields=fields)


class CollectionTypeDescription:
Expand Down Expand Up @@ -47,12 +47,15 @@ class CollectionTypeDescription:

collection_type: str

def __init__(self, collection_type: Union[str, "CollectionTypeDescription"], collection_type_description_factory):
def __init__(
self, collection_type: Union[str, "CollectionTypeDescription"], collection_type_description_factory, fields=None
):
if isinstance(collection_type, CollectionTypeDescription):
self.collection_type = collection_type.collection_type
else:
self.collection_type = collection_type
self.collection_type_description_factory = collection_type_description_factory
self.fields = fields
self.__has_subcollections = self.collection_type.find(":") > 0

def child_collection_type(self):
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/model/dataset_collections/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class DatasetCollectionType(metaclass=ABCMeta):
@abstractmethod
def generate_elements(self, dataset_instances):
def generate_elements(self, dataset_instances: dict, **kwds):
"""Generate DatasetCollectionElements with corresponding
to the supplied dataset instances or throw exception if
this is not a valid collection of the specified type.
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/model/dataset_collections/types/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ class ListDatasetCollectionType(BaseDatasetCollectionType):

collection_type = "list"

def generate_elements(self, elements):
for identifier, element in elements.items():
def generate_elements(self, dataset_instances, **kwds):
for identifier, element in dataset_instances.items():
association = DatasetCollectionElement(
element=element,
element_identifier=identifier,
Expand Down
8 changes: 4 additions & 4 deletions lib/galaxy/model/dataset_collections/types/paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@ class PairedDatasetCollectionType(BaseDatasetCollectionType):

collection_type = "paired"

def generate_elements(self, elements):
if forward_dataset := elements.get(FORWARD_IDENTIFIER):
def generate_elements(self, dataset_instances, **kwds):
if forward_dataset := dataset_instances.get(FORWARD_IDENTIFIER):
left_association = DatasetCollectionElement(
element=forward_dataset,
element_identifier=FORWARD_IDENTIFIER,
)
yield left_association
if reverse_dataset := elements.get(REVERSE_IDENTIFIER):
if reverse_dataset := dataset_instances.get(REVERSE_IDENTIFIER):
right_association = DatasetCollectionElement(
element=reverse_dataset,
element_identifier=REVERSE_IDENTIFIER,
)
yield right_association

def prototype_elements(self):
def prototype_elements(self, **kwds):
left_association = DatasetCollectionElement(
element=HistoryDatasetAssociation(),
element_identifier=FORWARD_IDENTIFIER,
Expand Down
45 changes: 45 additions & 0 deletions lib/galaxy/model/dataset_collections/types/record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from galaxy.exceptions import RequestParameterMissingException
from galaxy.model import (
DatasetCollectionElement,
HistoryDatasetAssociation,
)
from ..types import BaseDatasetCollectionType


class RecordDatasetCollectionType(BaseDatasetCollectionType):
"""Arbitrary CWL-style record type."""

collection_type = "record"

def generate_elements(self, dataset_instances, **kwds):
fields = kwds.get("fields", None)
if fields is None:
raise RequestParameterMissingException("Missing or null parameter 'fields' required for record types.")
if len(dataset_instances) != len(fields):
self._validation_failed("Supplied element do not match fields.")
index = 0
for identifier, element in dataset_instances.items():
field = fields[index]
if field["name"] != identifier:
self._validation_failed("Supplied element do not match fields.")

# TODO: validate type and such.
association = DatasetCollectionElement(
element=element,
element_identifier=identifier,
)
yield association
index += 1

def prototype_elements(self, fields=None, **kwds):
if fields is None:
raise RequestParameterMissingException("Missing or null parameter 'fields' required for record types.")
for field in fields:
name = field.get("name", None)
assert name
assert field.get("type", "File") # NS: this assert doesn't make sense as it is
field_dataset = DatasetCollectionElement(
element=HistoryDatasetAssociation(),
element_identifier=name,
)
yield field_dataset
8 changes: 7 additions & 1 deletion lib/galaxy/tool_util/parser/output_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,14 @@ def __init__(
collection_type_from_rules: Optional[str] = None,
structured_like: Optional[str] = None,
dataset_collector_descriptions: Optional[List[DatasetCollectionDescription]] = None,
fields=None,
) -> None:
self.collection_type = collection_type
self.collection_type_source = collection_type_source
self.collection_type_from_rules = collection_type_from_rules
self.structured_like = structured_like
self.dataset_collector_descriptions = dataset_collector_descriptions or []
self.fields = fields
if collection_type and collection_type_source:
raise ValueError("Cannot set both type and type_source on collection output.")
if (
Expand All @@ -424,6 +426,10 @@ def __init__(
raise ValueError(
"Cannot specify dynamic structure (discover_datasets) and collection type attributes structured_like or collection_type_from_rules."
)
if collection_type == "record" and fields is None:
raise ValueError("If record outputs are defined, fields must be defined as well.")
if fields is not None and collection_type != "record":
raise ValueError("If fields are specified for outputs, the collection type must be record.")
self.dynamic = bool(dataset_collector_descriptions)

def collection_prototype(self, inputs, type_registry):
Expand All @@ -433,7 +439,7 @@ def collection_prototype(self, inputs, type_registry):
else:
collection_type = self.collection_type
assert collection_type
collection_prototype = type_registry.prototype(collection_type)
collection_prototype = type_registry.prototype(collection_type, fields=self.fields)
collection_prototype.collection_type = collection_type
return collection_prototype

Expand Down
5 changes: 4 additions & 1 deletion lib/galaxy/tools/actions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,10 @@ def handle_output(name, output, hidden=None):
assert not element_identifiers # known_outputs must have been empty
element_kwds = dict(elements=collections_manager.ELEMENTS_UNINITIALIZED)
else:
element_kwds = dict(element_identifiers=element_identifiers)
element_kwds = dict(
element_identifiers=element_identifiers,
fields=output.structure.fields,
)
output_collections.create_collection(
output=output, name=name, completed_job=completed_job, **element_kwds
)
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/workflow/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def _find_collections_to_match(self, progress: "WorkflowProgress", step, all_inp
for input_dict in all_inputs:
name = input_dict["name"]
data = progress.replacement_for_input(self.trans, step, input_dict)
can_map_over = hasattr(data, "collection") # and data.collection.allow_implicit_mapping
can_map_over = hasattr(data, "collection") and data.collection.allow_implicit_mapping

if not can_map_over:
continue
Expand Down
Loading

0 comments on commit 5c456d1

Please sign in to comment.