From 5efe5e33f09a2f870a271bd6283d9248f859a2a5 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 24 May 2024 19:31:31 -0600 Subject: [PATCH] Unify `add_units` and `add_sorting` signature (#875) Co-authored-by: Cody Baker <51133164+CodyCBakerPhD@users.noreply.github.com> --- CHANGELOG.md | 1 + .../tools/spikeinterface/spikeinterface.py | 96 +++++++++++++++---- 2 files changed, 77 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b5d32850..64ef835eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ * Support for pathlib in source data schema validation. [PR #854](https://github.com/catalystneuro/neuroconv/pull/854) * Use `ZoneInfo` instead of `dateutil.tz` in the conversion gallery. [PR #858](https://github.com/catalystneuro/neuroconv/pull/858) * Exposed `progress_bar_class` to ecephys and ophys data iterators. [PR #861](https://github.com/catalystneuro/neuroconv/pull/861) +* Unified the signatures between `add_units`, `add_sorting` and `write_sorting` [PR #875](https://github.com/catalystneuro/neuroconv/pull/875) ### Testing * Add general test for metadata in-place modification by interfaces. [PR #815](https://github.com/catalystneuro/neuroconv/pull/815) diff --git a/src/neuroconv/tools/spikeinterface/spikeinterface.py b/src/neuroconv/tools/spikeinterface/spikeinterface.py index 000737357..ec9caa711 100644 --- a/src/neuroconv/tools/spikeinterface/spikeinterface.py +++ b/src/neuroconv/tools/spikeinterface/spikeinterface.py @@ -957,36 +957,39 @@ def add_units_table( unit_electrode_indices=None, ): """ - Primary method for writing a SortingExtractor object to an NWBFile. + Add sorting data to a NWBFile object as a Units table. + + This function extracts unit properties from a SortingExtractor object and writes them + to an NWBFile Units table, either in the primary units interface or the processing + module (for intermediate/historical data). It handles unit selection, property customization, + waveform data, and electrode mapping. Parameters ---------- sorting : spikeinterface.BaseSorting - nwbfile : NWBFile - unit_ids : list of int or list of str, optional - Controls the unit_ids that will be written to the nwb file. If None, all - units are written. + The SortingExtractor object containing unit data. + nwbfile : pynwb.NWBFile + The NWBFile object to write the unit data into. + unit_ids : list of int or str, optional + The specific unit IDs to write. If None, all units are written. property_descriptions : dict, optional - For each key in this dictionary which matches the name of a unit - property in sorting, adds the value as a description to that - custom unit column. + Custom descriptions for unit properties. Keys should match property names in `sorting`, + and values will be used as descriptions in the Units table. skip_properties : list of str, optional - Each string in this list that matches a unit property will not be written to the NWBFile. - write_in_processing_module : bool, default: False - How to save the units table in the nwb file. - - True will save it to the processing module to serve as a historical provenance for the official table. - - False will save it to the official NWBFile.Units position; recommended only for the final form of the data. + Unit properties to exclude from writing. units_table_name : str, default: 'units' - The name of the units table. If write_as=='units', then units_table_name must also be 'units'. + Name of the Units table. Must be 'units' if `write_in_processing_module` is False. unit_table_description : str, optional - Text description of the units table; it is recommended to include information such as the sorting method, - curation steps, etc. + Description for the Units table (e.g., sorting method, curation details). + write_in_processing_module : bool, default: False + If True, write to the processing module (intermediate data). If False, write to + the primary NWBFile.units table. waveform_means : np.ndarray, optional - Waveform mean (template) for each unit (num_units, num_samples, num_channels) + Waveform mean (template) for each unit. Shape: (num_units, num_samples, num_channels). waveform_sds : np.ndarray, optional - Waveform standard deviation for each unit (num_units, num_samples, num_channels) - unit_electrode_indices : list of lists or arrays, optional - For each unit, the indices of electrodes that each waveform_mean/sd correspond to. + Waveform standard deviation for each unit. Shape: (num_units, num_samples, num_channels). + unit_electrode_indices : list of lists of int, optional + For each unit, a list of electrode indices corresponding to waveform data. """ if not write_in_processing_module and units_table_name != "units": raise ValueError("When writing to the nwbfile.units table, the name of the table must be 'units'!") @@ -1188,7 +1191,45 @@ def add_sorting( write_as: Literal["units", "processing"] = "units", units_name: str = "units", units_description: str = "Autogenerated by neuroconv.", + waveform_means: Optional[np.ndarray] = None, + waveform_sds: Optional[np.ndarray] = None, + unit_electrode_indices=None, ): + """Add sorting data (units and their properties) to an NWBFile. + + This function serves as a convenient wrapper around `add_units_table` to match + Spikeinterface's `SortingExtractor` + + Parameters + ---------- + sorting : BaseSorting + The SortingExtractor object containing unit data. + nwbfile : pynwb.NWBFile, optional + The NWBFile object to write the unit data into. + unit_ids : list of int or str, optional + The specific unit IDs to write. If None, all units are written. + property_descriptions : dict, optional + Custom descriptions for unit properties. Keys should match property names in `sorting`, + and values will be used as descriptions in the Units table. + skip_properties : list of str, optional + Unit properties to exclude from writing. + skip_features : list of str, optional + Deprecated argument (to be removed). Previously used to skip spike features. + write_as : {'units', 'processing'}, default: 'units' + Where to write the unit data: + - 'units': Write to the primary NWBFile.units table. + - 'processing': Write to the processing module (intermediate data). + units_name : str, default: 'units' + Name of the Units table. Must be 'units' if `write_as` is 'units'. + units_description : str, optional + Description for the Units table (e.g., sorting method, curation details). + waveform_means : np.ndarray, optional + Waveform mean (template) for each unit. Shape: (num_units, num_samples, num_channels). + waveform_sds : np.ndarray, optional + Waveform standard deviation for each unit. Shape: (num_units, num_samples, num_channels). + unit_electrode_indices : list of lists of int, optional + For each unit, a list of electrode indices corresponding to waveform data. + """ if skip_features is not None: warnings.warn( @@ -1212,6 +1253,9 @@ def add_sorting( write_in_processing_module=write_in_processing_module, units_table_name=units_name, unit_table_description=units_description, + waveform_means=waveform_means, + waveform_sds=waveform_sds, + unit_electrode_indices=unit_electrode_indices, ) @@ -1229,6 +1273,9 @@ def write_sorting( write_as: Literal["units", "processing"] = "units", units_name: str = "units", units_description: str = "Autogenerated by neuroconv.", + waveform_means: Optional[np.ndarray] = None, + waveform_sds: Optional[np.ndarray] = None, + unit_electrode_indices=None, ): """ Primary method for writing a SortingExtractor object to an NWBFile. @@ -1273,6 +1320,12 @@ def write_sorting( units_name : str, default: 'units' The name of the units table. If write_as=='units', then units_name must also be 'units'. units_description : str, default: 'Autogenerated by neuroconv.' + waveform_means : np.ndarray, optional + Waveform mean (template) for each unit. Shape: (num_units, num_samples, num_channels). + waveform_sds : np.ndarray, optional + Waveform standard deviation for each unit. Shape: (num_units, num_samples, num_channels). + unit_electrode_indices : list of lists of int, optional + For each unit, a list of electrode indices corresponding to waveform data. """ with make_or_load_nwbfile( @@ -1288,6 +1341,9 @@ def write_sorting( write_as=write_as, units_name=units_name, units_description=units_description, + waveform_means=waveform_means, + waveform_sds=waveform_sds, + unit_electrode_indices=unit_electrode_indices, )