From 40d786af408e34efaea66edb387fde3d4b8cf32f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 14 Aug 2024 19:21:47 -0300 Subject: [PATCH] Drop casting to int in units table and introduce fine control over null values in `units_table` (#989) Co-authored-by: Cody Baker <51133164+CodyCBakerPhD@users.noreply.github.com> Co-authored-by: Ben Dichter --- CHANGELOG.md | 4 + .../tools/spikeinterface/__init__.py | 3 + .../tools/spikeinterface/spikeinterface.py | 121 ++++++++++-------- .../test_ecephys/test_tools_spikeinterface.py | 95 +++++++++++--- 4 files changed, 154 insertions(+), 69 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 81e5db922..b7a6ff48d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,14 @@ * Add Plexon2 support [PR #918](https://github.com/catalystneuro/neuroconv/pull/918) * Converter working with multiple VideoInterface instances [PR #914](https://github.com/catalystneuro/neuroconv/pull/914) * Added helper function `neuroconv.tools.data_transfers.submit_aws_batch_job` for basic automated submission of AWS batch jobs. [PR #384](https://github.com/catalystneuro/neuroconv/pull/384) +* Introduced `null_values_for_properties` to `add_units_table` to give user control over null values behavior [PR #989](https://github.com/catalystneuro/neuroconv/pull/989) + ### Bug fixes * Fixed the default naming of multiple electrical series in the `SpikeGLXConverterPipe`. [PR #957](https://github.com/catalystneuro/neuroconv/pull/957) * Write new properties to the electrode table use the global identifier channel_name, group [PR #984](https://github.com/catalystneuro/neuroconv/pull/984) +* Removed a bug where int64 was casted lossy to float [PR #989](https://github.com/catalystneuro/neuroconv/pull/989) + ### Improvements * The `OpenEphysBinaryRecordingInterface` now uses `lxml` for extracting the session start time from the settings.xml file and does not depend on `pyopenephys` anymore. [PR #971](https://github.com/catalystneuro/neuroconv/pull/971) diff --git a/src/neuroconv/tools/spikeinterface/__init__.py b/src/neuroconv/tools/spikeinterface/__init__.py index 2a796f492..165b76fc8 100644 --- a/src/neuroconv/tools/spikeinterface/__init__.py +++ b/src/neuroconv/tools/spikeinterface/__init__.py @@ -6,9 +6,12 @@ add_recording, add_sorting, add_units_table, + add_sorting_analyzer, check_if_recording_traces_fit_into_memory, get_nwb_metadata, write_recording, write_sorting, write_sorting_analyzer, + add_waveforms, + write_waveforms, ) diff --git a/src/neuroconv/tools/spikeinterface/spikeinterface.py b/src/neuroconv/tools/spikeinterface/spikeinterface.py index b31ebe2dd..246874f8b 100644 --- a/src/neuroconv/tools/spikeinterface/spikeinterface.py +++ b/src/neuroconv/tools/spikeinterface/spikeinterface.py @@ -1,7 +1,6 @@ import uuid import warnings from collections import defaultdict -from numbers import Real from typing import Any, List, Literal, Optional, Union import numpy as np @@ -399,8 +398,6 @@ def add_electrodes( exclude: tuple An iterable containing the string names of channel properties in the RecordingExtractor object to ignore when writing to the NWBFile. - null_values_for_properties: dict - A dictionary mapping channel properties to null values to use when the property is not present """ assert isinstance( nwbfile, pynwb.NWBFile @@ -807,7 +804,6 @@ def add_electrical_series( # Create a region for the electrodes table electrode_table_indices = _get_electrode_table_indices_for_recording(recording=recording, nwbfile=nwbfile) - electrode_table_indices = _get_electrode_table_indices_for_recording(recording=recording, nwbfile=nwbfile) electrode_table_region = nwbfile.create_electrode_table_region( region=electrode_table_indices, description="electrode_table_region", @@ -1107,6 +1103,7 @@ def add_units_table( waveform_means: Optional[np.ndarray] = None, waveform_sds: Optional[np.ndarray] = None, unit_electrode_indices=None, + null_values_for_properties: Optional[dict] = None, ): """ Add sorting data to a NWBFile object as a Units table. @@ -1142,12 +1139,16 @@ def add_units_table( Waveform standard deviation for each unit. Shape: (num_units, num_samples, num_channels). unit_electrode_indices : list of lists of int, optional For each unit, a list of electrode indices corresponding to waveform data. + null_values_for_properties: dict, optional + A dictionary mapping properties to null values to use when the property is not present """ assert isinstance( nwbfile, pynwb.NWBFile ), f"'nwbfile' should be of type pynwb.NWBFile but is of type {type(nwbfile)}" + null_values_for_properties = dict() if null_values_for_properties is None else null_values_for_properties + if not write_in_processing_module and units_table_name != "units": raise ValueError("When writing to the nwbfile.units table, the name of the table must be 'units'!") @@ -1203,8 +1204,7 @@ def add_units_table( # Extract properties for property in properties_to_extract: data = sorting.get_property(property) - if isinstance(data[0], (bool, np.bool_)): - data = data.astype(str) + index = isinstance(data[0], (list, np.ndarray, tuple)) if index and isinstance(data[0], np.ndarray): index = data[0].ndim @@ -1222,20 +1222,25 @@ def add_units_table( unit_name_array = unit_ids.astype("str", copy=False) data_to_add["unit_name"].update(description="Unique reference for each unit.", data=unit_name_array) - units_table_previous_properties = set(units_table.colnames) - {"spike_times"} - extracted_properties = set(data_to_add) - properties_to_add_by_rows = units_table_previous_properties | {"id"} - properties_to_add_by_columns = extracted_properties - properties_to_add_by_rows - - # Find default values for properties / columns already in the table - type_to_default_value = {list: [], np.ndarray: np.array(np.nan), str: "", Real: np.nan} - property_to_default_values = {"id": None} - for property in units_table_previous_properties: - # Find a matching data type and get the default value - sample_data = units_table[property].data[0] - matching_type = next(type for type in type_to_default_value if isinstance(sample_data, type)) - default_value = type_to_default_value[matching_type] - property_to_default_values.update({property: default_value}) + units_table_previous_properties = set(units_table.colnames).difference({"spike_times"}) + properties_to_add = set(data_to_add) + properties_to_add_by_rows = units_table_previous_properties.union({"id"}) + properties_to_add_by_columns = properties_to_add - properties_to_add_by_rows + + # Properties that were added before require null values to add by rows if data is missing + properties_requiring_null_values = units_table_previous_properties.difference(properties_to_add) + null_values_for_row = {} + for property in properties_requiring_null_values - {"electrodes"}: # TODO, fix electrodes + sample_data = units_table[property][:][0] + null_value = _get_null_value_for_property( + property=property, + sample_data=sample_data, + null_values_for_properties=null_values_for_properties, + ) + null_values_for_row[property] = null_value + + # Special case + null_values_for_row["id"] = None # Add data by rows excluding the rows with previously added unit names unit_names_used_previously = [] @@ -1259,7 +1264,7 @@ def add_units_table( rows_to_add.append(index) for row in rows_to_add: - unit_kwargs = dict(property_to_default_values) + unit_kwargs = null_values_for_row for property in properties_with_data: unit_kwargs[property] = data_to_add[property]["data"][row] spike_times = [] @@ -1278,9 +1283,9 @@ def add_units_table( if unit_electrode_indices is not None: unit_kwargs["electrodes"] = unit_electrode_indices[row] units_table.add_unit(spike_times=spike_times, **unit_kwargs, enforce_unique_id=True) - # added_unit_table_ids = units_table.id[-len(rows_to_add) :] # TODO - this line is unused? # Add unit_name as a column and fill previously existing rows with unit_name equal to str(ids) + unit_table_size = len(units_table.id[:]) previous_table_size = len(units_table.id[:]) - len(unit_name_array) if "unit_name" in properties_to_add_by_columns: cols_args = data_to_add["unit_name"] @@ -1299,41 +1304,47 @@ def add_units_table( unit_name: table_df.query(f"unit_name=='{unit_name}'").index[0] for unit_name in unit_name_array } - indexes_for_new_data = [unit_name_to_electrode_index[unit_name] for unit_name in unit_name_array] - indexes_for_default_values = table_df.index.difference(indexes_for_new_data).values + indices_for_new_data = [unit_name_to_electrode_index[unit_name] for unit_name in unit_name_array] + indices_for_null_values = table_df.index.difference(indices_for_new_data).values + extending_column = len(indices_for_null_values) > 0 # Add properties as columns for property in properties_to_add_by_columns - {"unit_name"}: cols_args = data_to_add[property] data = cols_args["data"] - if np.issubdtype(data.dtype, np.integer): - data = data.astype("float") - # Find first matching data-type - sample_data = data[0] - matching_type = next(type for type in type_to_default_value if isinstance(sample_data, type)) - default_value = type_to_default_value[matching_type] + # This is the simple case, early return + if not extending_column: + units_table.add_column(property, **cols_args) + continue + + # Extending the columns is done differently for ragged arrays + adding_ragged_array = cols_args["index"] + if not adding_ragged_array: + sample_data = data[0] + dtype = data.dtype + extended_data = np.empty(shape=unit_table_size, dtype=dtype) + extended_data[indices_for_new_data] = data + + null_value = _get_null_value_for_property( + property=property, + sample_data=sample_data, + null_values_for_properties=null_values_for_properties, + ) + extended_data[indices_for_null_values] = null_value + else: - if "index" in cols_args and cols_args["index"]: dtype = np.ndarray - extended_data = np.empty(shape=len(units_table.id[:]), dtype=dtype) + extended_data = np.empty(shape=unit_table_size, dtype=dtype) for index, value in enumerate(data): - index_in_extended_data = indexes_for_new_data[index] + index_in_extended_data = indices_for_new_data[index] extended_data[index_in_extended_data] = value.tolist() - for index in indexes_for_default_values: - default_value = [] - extended_data[index] = default_value - - else: - dtype = data.dtype - extended_data = np.empty(shape=len(units_table.id[:]), dtype=dtype) - extended_data[indexes_for_new_data] = data - extended_data[indexes_for_default_values] = default_value - - if np.issubdtype(extended_data.dtype, np.object_): - extended_data = extended_data.astype("str", copy=False) + for index in indices_for_null_values: + null_value = [] + extended_data[index] = null_value + # Add the data cols_args["data"] = extended_data units_table.add_column(property, **cols_args) @@ -1561,6 +1572,7 @@ def add_sorting_analyzer( The name of the units table. If write_as=='units', then units_name must also be 'units'. units_description : str, default: 'Autogenerated by neuroconv.' """ + # TODO: move into add_units assert write_as in [ "units", @@ -1669,7 +1681,7 @@ def write_sorting_analyzer( Controls the unit_ids that will be written to the nwb file. If None (default), all units are written. write_electrical_series : bool, default: False - If True, the recording object associated to t is written as an electrical series. + If True, the recording object associated to the analyzer is written as an electrical series. add_electrical_series_kwargs: dict, optional Keyword arguments to control the `add_electrical_series()` function in case write_electrical_series=True property_descriptions: dict, optional @@ -1688,16 +1700,17 @@ def write_sorting_analyzer( """ metadata = metadata if metadata is not None else dict() + if sorting_analyzer.has_recording(): + recording = sorting_analyzer.recording + assert recording is not None, ( + "recording not found. To add the electrode table, the sorting_analyzer " + "needs to have a recording attached or the 'recording' argument needs to be used." + ) + # try: with make_or_load_nwbfile( nwbfile_path=nwbfile_path, nwbfile=nwbfile, metadata=metadata, overwrite=overwrite, verbose=verbose ) as nwbfile_out: - if sorting_analyzer.has_recording(): - recording = sorting_analyzer.recording - assert recording is not None, ( - "recording not found. To add the electrode table, the sorting_analyzer " - "needs to have a recording attached or the 'recording' argument needs to be used." - ) if write_electrical_series: add_electrical_series_kwargs = add_electrical_series_kwargs or dict() @@ -1719,6 +1732,7 @@ def write_sorting_analyzer( ) +# TODO: Remove February 2025 def write_waveforms( waveform_extractor, nwbfile_path: Optional[FilePathType] = None, @@ -1746,6 +1760,9 @@ def write_waveforms( ) +# TODO: Remove February 2025 + + def add_waveforms( waveform_extractor, nwbfile: Optional[pynwb.NWBFile] = None, diff --git a/tests/test_ecephys/test_tools_spikeinterface.py b/tests/test_ecephys/test_tools_spikeinterface.py index 7b377a694..ec3702fcd 100644 --- a/tests/test_ecephys/test_tools_spikeinterface.py +++ b/tests/test_ecephys/test_tools_spikeinterface.py @@ -10,7 +10,7 @@ import pynwb.ecephys from hdmf.data_utils import DataChunkIterator from hdmf.testing import TestCase -from pynwb import NWBHDF5IO, NWBFile +from pynwb import NWBFile from spikeinterface.core.generate import ( generate_ground_truth_recording, generate_recording, @@ -18,7 +18,7 @@ ) from spikeinterface.extractors import NumpyRecording -from neuroconv.tools.nwb_helpers import get_default_nwbfile_metadata, get_module +from neuroconv.tools.nwb_helpers import get_module from neuroconv.tools.spikeinterface import ( add_electrical_series, add_electrodes, @@ -1263,7 +1263,7 @@ def test_write_bool_properties(self): nwbfile=self.nwbfile, ) self.assertIn("test_bool", self.nwbfile.units.colnames) - assert all(tb in ["False", "True"] for tb in self.nwbfile.units["test_bool"][:]) + assert all(tb in [False, True] for tb in self.nwbfile.units["test_bool"][:]) def test_adding_ragged_array_properties(self): @@ -1338,6 +1338,65 @@ def test_adding_doubled_ragged_arrays(self): for i, value in enumerate(written_values): np.testing.assert_array_equal(value, expected_values[i]) + def test_missing_int_values(self): + + sorting1 = generate_sorting(num_units=2, durations=[1.0]) + sorting1 = sorting1.rename_units(new_unit_ids=["a", "b"]) + sorting1.set_property(key="complete_int_property", values=[1, 2]) + add_units_table(sorting=sorting1, nwbfile=self.nwbfile) + + expected_property = np.asarray([1, 2]) + extracted_property = self.nwbfile.units["complete_int_property"].data + assert np.array_equal(extracted_property, expected_property) + + sorting2 = generate_sorting(num_units=2, durations=[1.0]) + sorting2 = sorting2.rename_units(new_unit_ids=["c", "d"]) + + sorting2.set_property(key="incomplete_int_property", values=[10, 11]) + with self.assertRaises(ValueError): + add_units_table(sorting=sorting2, nwbfile=self.nwbfile) + + null_values_for_properties = {"complete_int_property": -1, "incomplete_int_property": -3} + add_units_table(sorting=sorting2, nwbfile=self.nwbfile, null_values_for_properties=null_values_for_properties) + + expected_complete_property = np.asarray([1, 2, -1, -1]) + expected_incomplete_property = np.asarray([-3, -3, 10, 11]) + + extracted_complete_property = self.nwbfile.units["complete_int_property"].data + extracted_incomplete_property = self.nwbfile.units["incomplete_int_property"].data + + assert np.array_equal(extracted_complete_property, expected_complete_property) + assert np.array_equal(extracted_incomplete_property, expected_incomplete_property) + + def test_missing_bool_values(self): + sorting1 = generate_sorting(num_units=2, durations=[1.0]) + sorting1 = sorting1.rename_units(new_unit_ids=["a", "b"]) + sorting1.set_property(key="complete_bool_property", values=[True, False]) + add_units_table(sorting=sorting1, nwbfile=self.nwbfile) + + expected_property = np.asarray([True, False]) + extracted_property = self.nwbfile.units["complete_bool_property"].data.astype(bool) + assert np.array_equal(extracted_property, expected_property) + + sorting2 = generate_sorting(num_units=2, durations=[1.0]) + sorting2 = sorting2.rename_units(new_unit_ids=["c", "d"]) + + sorting2.set_property(key="incomplete_bool_property", values=[True, False]) + with self.assertRaises(ValueError): + add_units_table(sorting=sorting2, nwbfile=self.nwbfile) + + null_values_for_properties = {"complete_bool_property": False, "incomplete_bool_property": False} + add_units_table(sorting=sorting2, nwbfile=self.nwbfile, null_values_for_properties=null_values_for_properties) + + expected_complete_property = np.asarray([True, False, False, False]) + expected_incomplete_property = np.asarray([False, False, True, False]) + + extracted_complete_property = self.nwbfile.units["complete_bool_property"].data.astype(bool) + extracted_incomplete_property = self.nwbfile.units["incomplete_bool_property"].data.astype(bool) + + assert np.array_equal(extracted_complete_property, expected_complete_property) + assert np.array_equal(extracted_incomplete_property, expected_incomplete_property) + from neuroconv.tools import get_package_version @@ -1487,20 +1546,22 @@ def test_write_recordingless(self): write_electrical_series=True, ) - def test_write_sorting_analyzer_to_file(self): - """This tests that the analyzer is written to file""" - metadata = get_default_nwbfile_metadata() - metadata["NWBFile"]["session_start_time"] = datetime.now() - write_sorting_analyzer( - sorting_analyzer=self.single_segment_analyzer, - nwbfile_path=self.nwbfile_path, - write_electrical_series=True, - metadata=metadata, - ) - with NWBHDF5IO(self.nwbfile_path, "r") as io: - nwbfile = io.read() - self._test_analyzer_write(self.single_segment_analyzer, nwbfile) - self.assertIn("ElectricalSeriesRaw", nwbfile.acquisition) + # def test_write_sorting_analyzer_to_file(self): + # """This tests that the analyzer is written to file""" + # metadata = get_default_nwbfile_metadata() + # metadata["NWBFile"]["session_start_time"] = datetime.now() + + # write_sorting_analyzer( + # sorting_analyzer=self.single_segment_analyzer, + # nwbfile_path=self.nwbfile_path, + # write_electrical_series=True, + # metadata=metadata, + # ) + + # with NWBHDF5IO(self.nwbfile_path, "r") as io: + # nwbfile = io.read() + # self._test_analyzer_write(self.single_segment_analyzer, nwbfile) + # self.assertIn("ElectricalSeriesRaw", nwbfile.acquisition) def test_write_multiple_probes_without_electrical_series(self): """This test that the analyzer is written to different electrode groups"""