diff --git a/src/neuroconv/tools/nwb_helpers/_dataset_configuration.py b/src/neuroconv/tools/nwb_helpers/_dataset_configuration.py index 4e7783aff..7ddadb7a0 100644 --- a/src/neuroconv/tools/nwb_helpers/_dataset_configuration.py +++ b/src/neuroconv/tools/nwb_helpers/_dataset_configuration.py @@ -4,7 +4,6 @@ import h5py import numpy as np import zarr -from hdmf import Container from hdmf.data_utils import DataChunkIterator, DataIO, GenericDataChunkIterator from hdmf.utils import get_data_shape from hdmf_zarr import NWBZarrIO @@ -49,91 +48,15 @@ def _is_dataset_written_to_file( ) -def _find_location_in_memory_nwbfile(current_location: str, neurodata_object: Container) -> str: - """ - Method for determining the location of a neurodata object within an in-memory NWBFile object. - - Distinct from methods from other packages, such as the NWB Inspector, which rely on such files being read from disk. - """ - parent = neurodata_object.parent - if isinstance(parent, NWBFile): - # Items in defined top-level places like acquisition, intervals, etc. do not act as 'containers' - # in that they do not set the `.parent` attribute; ask if object is in their in-memory dictionaries instead - for parent_field_name, parent_field_value in parent.fields.items(): - if isinstance(parent_field_value, dict) and neurodata_object.name in parent_field_value: - return parent_field_name + "/" + neurodata_object.name + "/" + current_location - return neurodata_object.name + "/" + current_location - return _find_location_in_memory_nwbfile( - current_location=neurodata_object.name + "/" + current_location, neurodata_object=parent - ) - - -def _infer_dtype_using_data_chunk_iterator(candidate_dataset: Union[h5py.Dataset, zarr.Array]): - """ - The DataChunkIterator has one of the best generic dtype inference, though logic is hard to peel out of it. - - It can fail in rare cases but not essential to our default configuration - """ - try: - return DataChunkIterator(candidate_dataset).dtype - except Exception as exception: - if str(exception) != "Data type could not be determined. Please specify dtype in DataChunkIterator init.": - raise exception - else: - return np.dtype("object") - - -def _get_dataset_metadata( - neurodata_object: Union[TimeSeries, DynamicTable], field_name: str, backend: Literal["hdf5", "zarr"] -) -> Union[HDF5DatasetIOConfiguration, ZarrDatasetIOConfiguration, None]: - """Fill in the Dataset model with as many values as can be automatically detected or inferred.""" - DatasetIOConfigurationClass = BACKEND_TO_DATASET_CONFIGURATION[backend] - - candidate_dataset = getattr(neurodata_object, field_name) - - # For now, skip over datasets already wrapped in DataIO - # Could maybe eventually support modifying chunks in place - # But setting buffer shape only possible if iterator was wrapped first - if isinstance(candidate_dataset, DataIO): - return None - - dtype = _infer_dtype_using_data_chunk_iterator(candidate_dataset=candidate_dataset) - full_shape = get_data_shape(data=candidate_dataset) - - if isinstance(candidate_dataset, GenericDataChunkIterator): - chunk_shape = candidate_dataset.chunk_shape - buffer_shape = candidate_dataset.buffer_shape - elif dtype != "unknown": - # TODO: eventually replace this with staticmethods on hdmf.data_utils.GenericDataChunkIterator - chunk_shape = SliceableDataChunkIterator.estimate_default_chunk_shape( - chunk_mb=10.0, maxshape=full_shape, dtype=np.dtype(dtype) - ) - buffer_shape = SliceableDataChunkIterator.estimate_default_buffer_shape( - buffer_gb=0.5, chunk_shape=chunk_shape, maxshape=full_shape, dtype=np.dtype(dtype) - ) - else: - pass # TODO: think on this; perhaps zarr's standalone estimator? - - location = _find_location_in_memory_nwbfile(current_location=field_name, neurodata_object=neurodata_object) - dataset_info = DatasetInfo( - object_id=neurodata_object.object_id, - object_name=neurodata_object.name, - location=location, - full_shape=full_shape, - dtype=dtype, - ) - dataset_configuration = DatasetIOConfigurationClass( - dataset_info=dataset_info, chunk_shape=chunk_shape, buffer_shape=buffer_shape - ) - return dataset_configuration - - def get_default_dataset_io_configurations( nwbfile: NWBFile, backend: Union[None, Literal["hdf5", "zarr"]] = None, # None for auto-detect from append mode, otherwise required ) -> Generator[DatasetIOConfiguration, None, None]: """ - Method for automatically detecting all objects in the file that could be wrapped in a DataIO. + Generate DatasetIOConfiguration objects for wrapping NWB file objects with a specific backend. + + This method automatically detects all objects in an NWB file that can be wrapped in a DataIO. It supports auto-detection + of the backend if the NWB file is in append mode, otherwise it requires a backend specification. Parameters ---------- @@ -147,6 +70,8 @@ def get_default_dataset_io_configurations( DatasetIOConfiguration A summary of each detected object that can be wrapped in a DataIO. """ + DatasetIOConfigurationClass = BACKEND_TO_DATASET_CONFIGURATION[backend] + if backend is None and nwbfile.read_io is None: raise ValueError( "Keyword argument `backend` (either 'hdf5' or 'zarr') must be specified if the `nwbfile` was not " @@ -185,7 +110,15 @@ def get_default_dataset_io_configurations( ): continue # skip - yield _get_dataset_metadata(neurodata_object=column, field_name="data", backend=backend) + # Skip over columns that are already wrapped in DataIO + if isinstance(candidate_dataset, DataIO): + continue + + dataset_io_configuration = DatasetIOConfigurationClass.from_neurodata_object( + neurodata_object=column, field_name="data" + ) + + yield dataset_io_configuration else: # Primarily for TimeSeries, but also any extended class that has 'data' or 'timestamps' # The most common example of this is ndx-events Events/LabeledEvents types @@ -201,8 +134,16 @@ def get_default_dataset_io_configurations( ): continue # skip + # Skip over datasets that are already wrapped in DataIO + if isinstance(candidate_dataset, DataIO): + continue + # Edge case of in-memory ImageSeries with external mode; data is in fields and is empty array if isinstance(candidate_dataset, np.ndarray) and candidate_dataset.size == 0: continue # skip - yield _get_dataset_metadata(neurodata_object=time_series, field_name=field_name, backend=backend) + dataset_io_configuration = DatasetIOConfigurationClass.from_neurodata_object( + neurodata_object=time_series, field_name=field_name + ) + + yield dataset_io_configuration diff --git a/src/neuroconv/tools/nwb_helpers/_models/_base_models.py b/src/neuroconv/tools/nwb_helpers/_models/_base_models.py index 8a6486e74..aef29a62e 100644 --- a/src/neuroconv/tools/nwb_helpers/_models/_base_models.py +++ b/src/neuroconv/tools/nwb_helpers/_models/_base_models.py @@ -6,8 +6,50 @@ import h5py import numcodecs import numpy as np +import zarr +from hdmf import Container from hdmf.container import DataIO +from hdmf.data_utils import DataChunkIterator, DataIO, GenericDataChunkIterator +from hdmf.utils import get_data_shape from pydantic import BaseModel, Field, root_validator +from pynwb import NWBHDF5IO, NWBFile + +from ...hdmf import SliceableDataChunkIterator + + +def _find_location_in_memory_nwbfile(current_location: str, neurodata_object: Container) -> str: + """ + Method for determining the location of a neurodata object within an in-memory NWBFile object. + + Distinct from methods from other packages, such as the NWB Inspector, which rely on such files being read from disk. + """ + parent = neurodata_object.parent + if isinstance(parent, NWBFile): + # Items in defined top-level places like acquisition, intervals, etc. do not act as 'containers' + # in that they do not set the `.parent` attribute; ask if object is in their in-memory dictionaries instead + for parent_field_name, parent_field_value in parent.fields.items(): + if isinstance(parent_field_value, dict) and neurodata_object.name in parent_field_value: + return parent_field_name + "/" + neurodata_object.name + "/" + current_location + return neurodata_object.name + "/" + current_location + return _find_location_in_memory_nwbfile( + current_location=neurodata_object.name + "/" + current_location, neurodata_object=parent + ) + + +def _infer_dtype_using_data_chunk_iterator(candidate_dataset: Union[h5py.Dataset, zarr.Array]): + """ + The DataChunkIterator has one of the best generic dtype inference, though logic is hard to peel out of it. + + It can fail in rare cases but not essential to our default configuration + """ + try: + data_type = DataChunkIterator(candidate_dataset).dtype + return data_type + except Exception as exception: + if str(exception) != "Data type could not be determined. Please specify dtype in DataChunkIterator init.": + raise exception + else: + return np.dtype("object") class DatasetInfo(BaseModel): @@ -61,6 +103,22 @@ def __init__(self, **values): values.update(dataset_name=dataset_name) super().__init__(**values) + @classmethod + def from_neurodata_object(cls, neurodata_object: Container, field_name: str) -> "DatasetInfo": + location = _find_location_in_memory_nwbfile(current_location=field_name, neurodata_object=neurodata_object) + candidate_dataset = getattr(neurodata_object, field_name) + + full_shape = get_data_shape(data=candidate_dataset) + dtype = _infer_dtype_using_data_chunk_iterator(candidate_dataset=candidate_dataset) + + return cls( + object_id=neurodata_object.object_id, + object_name=neurodata_object.name, + location=location, + full_shape=full_shape, + dtype=dtype, + ) + class DatasetIOConfiguration(BaseModel, ABC): """A data model for configuring options about an object that will become a HDF5 or Zarr Dataset in the file.""" @@ -182,6 +240,31 @@ def get_data_io_kwargs(self) -> Dict[str, Any]: """ raise NotImplementedError + @classmethod + def from_neurodata_object(cls, neurodata_object: Container, field_name: str) -> "DatasetIOConfiguration": + candidate_dataset = getattr(neurodata_object, field_name) + + dataset_info = DatasetInfo.from_neurodata_object(neurodata_object=neurodata_object, field_name=field_name) + + dtype = dataset_info.dtype + full_shape = dataset_info.full_shape + + if isinstance(candidate_dataset, GenericDataChunkIterator): + chunk_shape = candidate_dataset.chunk_shape + buffer_shape = candidate_dataset.buffer_shape + elif dtype != "unknown": + # TODO: eventually replace this with staticmethods on hdmf.data_utils.GenericDataChunkIterator + chunk_shape = SliceableDataChunkIterator.estimate_default_chunk_shape( + chunk_mb=10.0, maxshape=full_shape, dtype=np.dtype(dtype) + ) + buffer_shape = SliceableDataChunkIterator.estimate_default_buffer_shape( + buffer_gb=0.5, chunk_shape=chunk_shape, maxshape=full_shape, dtype=np.dtype(dtype) + ) + else: + pass + + return cls(dataset_info=dataset_info, chunk_shape=chunk_shape, buffer_shape=buffer_shape) + class BackendConfiguration(BaseModel): """A model for matching collections of DatasetConfigurations to a specific backend."""