Skip to content

Commit

Permalink
WIP: sample sheets...
Browse files Browse the repository at this point in the history
  • Loading branch information
jmchilton committed Dec 10, 2024
1 parent 2982f2c commit fc02127
Show file tree
Hide file tree
Showing 18 changed files with 603 additions and 14 deletions.
17 changes: 16 additions & 1 deletion lib/galaxy/managers/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def create(
completed_job=None,
output_name=None,
fields=None,
column_definitions=None,
rows=None,
):
"""
PRECONDITION: security checks on ability to add to parent
Expand All @@ -201,6 +203,8 @@ def create(
copy_elements=copy_elements,
history=history,
fields=fields,
column_definitions=column_definitions,
rows=rows,
)

implicit_inputs = []
Expand Down Expand Up @@ -288,6 +292,8 @@ def create_dataset_collection(
copy_elements=False,
history=None,
fields=None,
column_definitions=None,
rows=None,
):
# Make sure at least one of these is None.
assert element_identifiers is None or elements is None
Expand Down Expand Up @@ -324,9 +330,12 @@ def create_dataset_collection(

if elements is not self.ELEMENTS_UNINITIALIZED:
type_plugin = collection_type_description.rank_type_plugin()
dataset_collection = builder.build_collection(type_plugin, elements, fields=fields)
dataset_collection = builder.build_collection(
type_plugin, elements, fields=fields, column_definitions=column_definitions, rows=rows
)
else:
# TODO: Pass fields here - need test case first.
# TODO: same with column definitions I think.
dataset_collection = model.DatasetCollection(populated=False)
dataset_collection.collection_type = collection_type
return dataset_collection
Expand Down Expand Up @@ -783,10 +792,16 @@ def __init_rule_data(self, elements, collection_type_description, parent_identif
identifiers = parent_identifiers + [element.element_identifier]
if not element.is_collection:
data.append([])
columns = None
collection_type_str = collection_type_description.collection_type
if collection_type_str == "sample_sheet":
columns = element.columns
assert isinstance(columns, list)
source = {
"identifiers": identifiers,
"dataset": element_object,
"tags": element_object.make_tag_string_list(),
"columns": columns,
}
sources.append(source)
else:
Expand Down
6 changes: 6 additions & 0 deletions lib/galaxy/managers/collections_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
exceptions,
model,
)
from galaxy.model.dataset_collections.types.sample_sheet_util import validate_column_definitions
from galaxy.util import string_as_bool

log = logging.getLogger(__name__)
Expand All @@ -33,13 +34,18 @@ def api_payload_to_create_params(payload):
message = f"Missing required parameters {missing_parameters}"
raise exceptions.ObjectAttributeMissingException(message)

column_definitions = payload.get("column_definitions", None)
validate_column_definitions(column_definitions)

params = dict(
collection_type=payload.get("collection_type"),
element_identifiers=payload.get("element_identifiers"),
name=payload.get("name", None),
hide_source_items=string_as_bool(payload.get("hide_source_items", False)),
copy_elements=string_as_bool(payload.get("copy_elements", False)),
fields=payload.get("fields", None),
column_definitions=column_definitions,
rows=payload.get("rows", None),
)
return params

Expand Down
16 changes: 14 additions & 2 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@
DatasetValidatedState,
InvocationsStateCounts,
JobState,
SampleSheetColumnDefinitions,
SampleSheetRow,
ToolRequestState,
)
from galaxy.schema.workflow.comments import WorkflowCommentModel
Expand Down Expand Up @@ -260,6 +262,7 @@ class ConfigurationTemplateEnvironmentVariable(TypedDict):
CONFIGURATION_TEMPLATE_CONFIGURATION_VARIABLES_TYPE = Dict[str, CONFIGURATION_TEMPLATE_CONFIGURATION_VALUE_TYPE]
CONFIGURATION_TEMPLATE_CONFIGURATION_SECRET_NAMES_TYPE = List[str]
CONFIGURATION_TEMPLATE_DEFINITION_TYPE = Dict[str, Any]
DATA_COLLECTION_FIELDS = List[Dict[str, Any]]


class TransformAction(TypedDict):
Expand Down Expand Up @@ -6521,6 +6524,10 @@ class DatasetCollection(Base, Dictifiable, UsesAnnotations, Serializable):
element_count: Mapped[Optional[int]]
create_time: Mapped[datetime] = mapped_column(default=now, nullable=True)
update_time: Mapped[datetime] = mapped_column(default=now, onupdate=now, nullable=True)
# if collection_type is 'record' (heterogenous collection)
fields: Mapped[Optional[DATA_COLLECTION_FIELDS]] = mapped_column(JSONType)
# if collection_type is 'sample_sheet' (collection of rows pairing datasets with extra column metadata)
column_definitions: Mapped[Optional[SampleSheetColumnDefinitions]] = mapped_column(JSONType)

elements: Mapped[List["DatasetCollectionElement"]] = relationship(
primaryjoin=(lambda: DatasetCollection.id == DatasetCollectionElement.dataset_collection_id),
Expand All @@ -6540,14 +6547,15 @@ def __init__(
populated=True,
element_count=None,
fields=None,
column_definitions=None,
):
self.id = id
self.collection_type = collection_type
if not populated:
self.populated_state = DatasetCollection.populated_states.NEW
self.element_count = element_count
# TODO: persist fields...
self.fields = fields
self.column_definitions = column_definitions

def _build_nested_collection_attributes_stmt(
self,
Expand Down Expand Up @@ -6956,6 +6964,7 @@ def _base_to_dict(self, view):
name=self.name,
collection_id=self.collection_id,
collection_type=self.collection.collection_type,
column_definitions=self.collection.column_definitions,
populated=self.populated,
populated_state=self.collection.populated_state,
populated_state_message=self.collection.populated_state_message,
Expand Down Expand Up @@ -7443,6 +7452,7 @@ class DatasetCollectionElement(Base, Dictifiable, Serializable):
# Element index and identifier to define this parent-child relationship.
element_index: Mapped[Optional[int]]
element_identifier: Mapped[Optional[str]] = mapped_column(Unicode(255))
columns: Mapped[Optional[SampleSheetRow]] = mapped_column(JSONType)

hda = relationship(
"HistoryDatasetAssociation",
Expand All @@ -7463,7 +7473,7 @@ class DatasetCollectionElement(Base, Dictifiable, Serializable):

# actionable dataset id needs to be available via API...
dict_collection_visible_keys = ["id", "element_type", "element_index", "element_identifier"]
dict_element_visible_keys = ["id", "element_type", "element_index", "element_identifier"]
dict_element_visible_keys = ["id", "element_type", "element_index", "element_identifier", "columns"]

UNINITIALIZED_ELEMENT = object()

Expand All @@ -7474,6 +7484,7 @@ def __init__(
element=None,
element_index=None,
element_identifier=None,
columns: Optional[SampleSheetRow] = None,
):
if isinstance(element, HistoryDatasetAssociation):
self.hda = element
Expand All @@ -7489,6 +7500,7 @@ def __init__(
self.collection = collection
self.element_index = element_index
self.element_identifier = element_identifier or str(element_index)
self.columns = columns

def __strict_check_before_flush__(self):
if self.collection.populated_optimized:
Expand Down
25 changes: 20 additions & 5 deletions lib/galaxy/model/dataset_collections/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,42 @@
from .type_description import COLLECTION_TYPE_DESCRIPTION_FACTORY


def build_collection(type, dataset_instances, collection=None, associated_identifiers=None, fields=None):
def build_collection(
    type,
    dataset_instances,
    collection=None,
    associated_identifiers=None,
    fields=None,
    column_definitions=None,
    rows=None,
):
    """Return a DatasetCollection populated with DatasetCollectionElement objects.

    The elements correspond to the supplied ``dataset_instances``; an exception
    is thrown if they do not form a valid collection of the specified type.
    ``column_definitions`` and ``rows`` carry sample-sheet metadata through to
    the new collection and its elements.
    """
    # Reuse the caller's collection when given; otherwise create a fresh one.
    target_collection = collection or model.DatasetCollection(fields=fields, column_definitions=column_definitions)
    seen_identifiers = associated_identifiers or set()
    set_collection_elements(target_collection, type, dataset_instances, seen_identifiers, fields=fields, rows=rows)
    return target_collection


def set_collection_elements(dataset_collection, type, dataset_instances, associated_identifiers, fields=None):
def set_collection_elements(
dataset_collection, type, dataset_instances, associated_identifiers, fields=None, rows=None
):
new_element_keys = OrderedSet(dataset_instances.keys()) - associated_identifiers
new_dataset_instances = {k: dataset_instances[k] for k in new_element_keys}
dataset_collection.element_count = dataset_collection.element_count or 0
element_index = dataset_collection.element_count
elements = []
if fields == "auto":
fields = guess_fields(dataset_instances)
for element in type.generate_elements(new_dataset_instances, fields=fields):
column_definitions = dataset_collection.column_definitions
for element in type.generate_elements(
new_dataset_instances, fields=fields, rows=rows, column_definitions=column_definitions
):
element.element_index = element_index
add_object_to_object_session(element, dataset_collection)
element.collection = dataset_collection
Expand Down
2 changes: 2 additions & 0 deletions lib/galaxy/model/dataset_collections/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
list,
paired,
record,
sample_sheet,
)

PLUGIN_CLASSES = [
list.ListDatasetCollectionType,
paired.PairedDatasetCollectionType,
record.RecordDatasetCollectionType,
sample_sheet.SampleSheetDatasetCollectionType,
]


Expand Down
30 changes: 30 additions & 0 deletions lib/galaxy/model/dataset_collections/types/sample_sheet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from galaxy.exceptions import RequestParameterMissingException
from galaxy.model import DatasetCollectionElement
from . import BaseDatasetCollectionType
from .sample_sheet_util import validate_row


class SampleSheetDatasetCollectionType(BaseDatasetCollectionType):
    """A flat list of named elements, each paired with a row of column metadata."""

    collection_type = "sample_sheet"

    def generate_elements(self, dataset_instances, **kwds):
        """Yield a DatasetCollectionElement per dataset instance with its row attached.

        Requires ``rows`` in kwds: a mapping from element identifier to a list of
        column values. Each row is validated against ``column_definitions`` (when
        supplied) before the element is produced.

        :raises RequestParameterMissingException: if ``rows`` is missing or null.
        """
        rows = kwds.get("rows", None)
        column_definitions = kwds.get("column_definitions", None)
        if rows is None:
            raise RequestParameterMissingException(
                "Missing or null parameter 'rows' required for 'sample_sheet' collection types."
            )
        if len(dataset_instances) != len(rows):
            # Fixed grammar of the failure message (was "Supplied element do not match").
            self._validation_failed("Supplied elements do not match 'rows'.")

        for identifier, element in dataset_instances.items():
            # Guard the lookup so a missing row surfaces as a validation failure
            # rather than an opaque KeyError.
            if identifier not in rows:
                self._validation_failed(f"Supplied elements do not match 'rows': no row for '{identifier}'.")
            columns = rows[identifier]
            validate_row(columns, column_definitions)
            association = DatasetCollectionElement(
                element=element,
                element_identifier=identifier,
                columns=columns,
            )
            yield association
101 changes: 101 additions & 0 deletions lib/galaxy/model/dataset_collections/types/sample_sheet_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import (
List,
Optional,
Union,
)

from pydantic import (
BaseModel,
ConfigDict,
RootModel,
TypeAdapter,
)

from galaxy.exceptions import RequestParameterInvalidException
from galaxy.schema.schema import (
SampleSheetColumnDefinition,
SampleSheetColumnDefinitions,
SampleSheetColumnType,
SampleSheetColumnValueT,
SampleSheetRow,
)
from galaxy.tool_util.parser.parameter_validators import (
AnySafeValidatorModel,
DiscriminatedAnySafeValidatorModel,
parse_dict_validators,
UnsafeValidatorConfiguredInUntrustedContext,
)


class SampleSheetColumnDefinitionModel(BaseModel):
    # Definition of a single sample-sheet column, validated strictly:
    # unknown keys are rejected so typos in payloads fail loudly.
    model_config = ConfigDict(extra="forbid")
    # Declared value type of the column ("int", "float", "string", "boolean" —
    # the types checked by validate_column_value below).
    type: SampleSheetColumnType
    # Only "safe" (statically evaluable) validators are accepted here, since
    # definitions may come from untrusted API payloads.
    validators: Optional[List[AnySafeValidatorModel]] = None
    # Closed set of allowed values; enforced by validate_column_value.
    restrictions: Optional[List[SampleSheetColumnValueT]] = None
    # Non-binding candidate values — presumably for UI hints; not enforced here.
    suggestions: Optional[List[SampleSheetColumnValueT]] = None


SampleSheetColumnDefinitionsModel = RootModel[List[SampleSheetColumnDefinitionModel]]
SampleSheetColumnDefinitionDictOrModel = Union[SampleSheetColumnDefinition, SampleSheetColumnDefinitionModel]


def sample_sheet_column_definition_to_model(
    column_definition: SampleSheetColumnDefinitionDictOrModel,
) -> SampleSheetColumnDefinitionModel:
    """Coerce a raw column-definition dict into a validated model (no-op for models)."""
    if isinstance(column_definition, SampleSheetColumnDefinitionModel):
        return column_definition
    return SampleSheetColumnDefinitionModel.model_validate(column_definition)


def validate_column_definitions(column_definitions: Optional[SampleSheetColumnDefinitions]):
    """Validate each supplied column definition; None or empty input is a no-op."""
    if not column_definitions:
        return
    for definition in column_definitions:
        _validate_column_definition(definition)


def _validate_column_definition(column_definition: SampleSheetColumnDefinition):
    """Validate a single column-definition dict by round-tripping it through pydantic.

    The model does most of the work; the key goal is ensuring only *safe*
    validators appear when definitions come from untrusted contexts.
    """
    model = SampleSheetColumnDefinitionModel(**column_definition)
    return model


def validate_row(row: SampleSheetRow, column_definitions: Optional[SampleSheetColumnDefinitions]):
    """Validate one sample-sheet row against the column definitions.

    Skipped entirely when no definitions are supplied.

    :raises RequestParameterInvalidException: on a column-count mismatch or an
        invalid column value.
    """
    if column_definitions is None:
        return
    if len(column_definitions) != len(row):
        raise RequestParameterInvalidException(
            "Sample sheet row validation failed, incorrect number of columns specified."
        )
    for value, definition in zip(row, column_definitions):
        validate_column_value(value, definition)


def validate_column_value(
    column_value: SampleSheetColumnValueT, column_definition: SampleSheetColumnDefinitionDictOrModel
):
    """Validate a single column value against its column definition.

    Checks, in order: the declared column type, the optional ``restrictions``
    allow-list, and any attached safe validators.

    :raises RequestParameterInvalidException: if any check fails.
    """
    column_definition_model = sample_sheet_column_definition_to_model(column_definition)
    column_type = column_definition_model.type
    # bool is a subclass of int in Python, so explicitly reject booleans for the
    # numeric column types — otherwise True/False would validate as 1/0.
    is_bool = isinstance(column_value, bool)
    if column_type == "int":
        if is_bool or not isinstance(column_value, int):
            raise RequestParameterInvalidException(f"{column_value} was not an integer as expected")
    elif column_type == "float":
        if is_bool or not isinstance(column_value, (float, int)):
            raise RequestParameterInvalidException(f"{column_value} was not a number as expected")
    elif column_type == "string":
        if not isinstance(column_value, str):
            raise RequestParameterInvalidException(f"{column_value} was not a string as expected")
    elif column_type == "boolean":
        if not is_bool:
            raise RequestParameterInvalidException(f"{column_value} was not a boolean as expected")
    restrictions = column_definition_model.restrictions
    if restrictions is not None:
        if column_value not in restrictions:
            raise RequestParameterInvalidException(
                f"{column_value} was not in specified list of valid values as expected"
            )
    validators = column_definition_model.validators or []
    for validator in validators:
        try:
            validator.statically_validate(column_value)
        except ValueError as e:
            # Chain the original error so the root cause survives in tracebacks.
            raise RequestParameterInvalidException(str(e)) from e
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
def upgrade():
    """Add sample-sheet support columns: per-collection definitions/fields and per-element row values."""
    new_columns = (
        (dataset_collection_table, "column_definitions"),
        (dataset_collection_table, "fields"),
        (dataset_collection_element_table, "columns"),
    )
    with transaction():
        for table, column_name in new_columns:
            add_column(table, Column(column_name, JSONType(), default=None))

def downgrade():
    """Drop the sample-sheet columns added by upgrade()."""
    with transaction():
        for table, column_name in (
            (dataset_collection_table, "column_definitions"),
            (dataset_collection_table, "fields"),
            (dataset_collection_element_table, "columns"),
        ):
            drop_column(table, column_name)
Loading

0 comments on commit fc02127

Please sign in to comment.