From 00bcd547d3352132477e0c9b726c7d39078e0d45 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 15 Dec 2023 13:55:39 +0100 Subject: [PATCH] Add hdf5 backend support for `NwbSortingExtractor` II (#2341) * fast recorder * factory pattern * docstring * change fast_mode to backend and improve docstring --- .../extractors/nwbextractors.py | 257 +++++++++++++++--- .../extractors/tests/test_nwbextractors.py | 29 +- 2 files changed, 244 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 6d3b49cb3f..ef570a5f4e 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -202,7 +202,7 @@ def __init__( file_path: str | Path | None = None, # provide either this or file electrical_series_name: str | None = None, load_time_vector: bool = False, - samples_for_rate_estimation: int = 1000, + samples_for_rate_estimation: int = 1_000, stream_mode: Optional[Literal["fsspec", "ros3", "remfile"]] = None, stream_cache_path: str | Path | None = None, *, @@ -399,7 +399,7 @@ def __init__( class _NWBHDF5RecordingExtractor(BaseRecording): """ A RecordingExtractor for NWB files. This uses the hdf5 API to extract the traces and - the metadata and is called by the NwbRecordingExtractor factory. This is faster + the metadata and is called by the NwbRecordingExtractor factory. This should be faster as it avoids the pynwb validation overhead. 
""" @@ -408,7 +408,7 @@ def __init__( file_path: str | Path | None = None, # provide either this or file electrical_series_name: str | None = None, load_time_vector: bool = False, - samples_for_rate_estimation: int = 10_0000, + samples_for_rate_estimation: int = 1_000, stream_mode: Optional[Literal["fsspec", "ros3", "remfile"]] = None, stream_cache_path: str | Path | None = None, *, @@ -798,45 +798,143 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces -class NwbSortingExtractor(BaseSorting): - """Load an NWBFile as a SortingExtractor. +class _NwbHDF5SortingExtractor(BaseSorting): + def __init__( + self, + *, + file_path: str | Path, + electrical_series_name: str | None = None, + sampling_frequency: float | None = None, + samples_for_rate_estimation: int = 1_000, + stream_mode: str | None = None, + stream_cache_path: str | Path | None = None, + cache: bool = False, + t_start: float | None = None, + ): + self.stream_mode = stream_mode + self.stream_cache_path = stream_cache_path - Parameters - ---------- - file_path: str or Path - Path to NWB file. - electrical_series_name: str or None, default: None - The name of the ElectricalSeries (if multiple ElectricalSeries are present). - sampling_frequency: float or None, default: None - The sampling frequency in Hz (required if no ElectricalSeries is available). - samples_for_rate_estimation: int, default: 1000 - The number of timestamp samples to use to estimate the rate. - Used if "rate" is not specified in the ElectricalSeries. - stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None - The streaming mode to use. If None it assumes the file is on the local disk. - cache: bool, default: False - If True, the file is cached in the file passed to stream_cache_path - if False, the file is not cached. - stream_cache_path: str or Path or None, default: None - Local path for caching. If None it uses the system temporary directory. 
- t_start: float or None, default: None - This is the time at which the corresponding ElectricalSeries start. NWB stores its spikes as times - and the `t_start` is used to convert the times to seconds. Concrently, the returned frames are computed as: + hdf5_file = _read_hdf5_file( + file_path=file_path, + stream_mode=stream_mode, + cache=cache, + stream_cache_path=stream_cache_path, + ) - `frames = (times - t_start) * sampling_frequency`. + timestamps = None + self.t_start = t_start - As SpikeInterface always considers the first frame to be at the beginning of the recording independently - of the `t_start`. + if sampling_frequency is None or t_start is None: + # defines the electrical series from where the sorting came from + # important to know the sampling_frequency + available_electrical_series = _NWBHDF5RecordingExtractor.find_electrical_series(hdf5_file) + if electrical_series_name is None: + if len(available_electrical_series) == 1: + electrical_series_name = list(available_electrical_series.keys())[0] + else: + raise ValueError( + "Multiple ElectricalSeries found in the file. " + "Please specify the 'electrical_series_name' argument:" + f"Available options are: {available_electrical_series}." + ) + else: + if electrical_series_name not in available_electrical_series: + raise ValueError( + f"'{electrical_series_name}' not found in the file. " + f"Available options are: {available_electrical_series}" + ) + self.electrical_series_location = available_electrical_series[electrical_series_name] + electrical_series = hdf5_file[self.electrical_series_location] - When a `t_start` is not provided it will be inferred from the corresponding ElectricalSeries with name equal - to `electrical_series_name`. The `t_start` then will be either the `ElectricalSeries.starting_time` or the - first timestamp in the `ElectricalSeries.timestamps`. 
+ # Get sampling frequency + if "starting_time" in electrical_series.keys(): + t_start = electrical_series["starting_time"][()] + sampling_frequency = electrical_series["starting_time"].attrs["rate"] + elif "timestamps" in electrical_series.keys(): + timestamps = electrical_series["timestamps"][:] + t_start = timestamps[0] + sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) + self.t_start = t_start - Returns - ------- - sorting: NwbSortingExtractor - The sorting extractor for the NWB file. + assert ( + sampling_frequency is not None + ), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument" + assert ( + self.t_start is not None + ), "Couldn't load a starting time for the sorting. Please provide it with the 't_start' argument" + + units_table = hdf5_file["units"] + spike_times_data = units_table["spike_times"] + spike_times_index_data = units_table["spike_times_index"] + + if "unit_name" in units_table: + unit_ids = units_table["unit_name"] + else: + unit_ids = units_table["id"] + + decode_to_string = lambda x: x.decode("utf-8") if isinstance(x, bytes) else x + unit_ids = [decode_to_string(id) for id in unit_ids] + BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids) + + sorting_segment = NwbSortingSegment( + spike_times_data=spike_times_data, + spike_times_index_data=spike_times_index_data, + sampling_frequency=sampling_frequency, + t_start=self.t_start, + ) + self.add_sorting_segment(sorting_segment) + + # Skip canonical properties and indices + caonical_properties = ["spike_times", "spike_times_index", "unit_name"] + index_properties = [name for name in units_table if name.endswith("_index")] + nested_ragged_array_properties = [name for name in units_table if f"{name}_index_index" in units_table] + + skip_properties = caonical_properties + index_properties + nested_ragged_array_properties + properties_to_add = [name for name in units_table if name not in 
skip_properties] + + for property_name in properties_to_add: + data = units_table[property_name] + corresponding_index_name = f"{property_name}_index" + not_ragged_array = corresponding_index_name not in units_table + if not_ragged_array: + values = data[:] + else: + data_index = units_table[corresponding_index_name][:] + index_spacing = np.diff(data_index, prepend=0) + all_index_spacing_are_the_same = np.unique(index_spacing).size == 1 + if all_index_spacing_are_the_same: + start_indices = [0] + list(data_index[:-1]) + end_indices = list(data_index) + values = [data[start_index:end_index] for start_index, end_index in zip(start_indices, end_indices)] + + else: + warnings.warn(f"Skipping {property_name} because of unequal shapes across units") + continue + + decode_to_string = lambda x: x.decode("utf-8") if isinstance(x, bytes) else x + values = [decode_to_string(val) for val in values] + self.set_property(property_name, np.asarray(values)) + + if stream_mode is None and file_path is not None: + file_path = str(Path(file_path).resolve()) + + self._kwargs = { + "file_path": file_path, + "electrical_series_name": electrical_series_name, + "sampling_frequency": sampling_frequency, + "samples_for_rate_estimation": samples_for_rate_estimation, + "cache": cache, + "stream_mode": stream_mode, + "stream_cache_path": stream_cache_path, + "t_start": self.t_start, + } + + +class _NwbPynwbSortingExtractor(BaseSorting): + """ + A SortingExtractor for NWB files. This uses the NWB API to extract the traces and + the metadata and is called by the NwbSortingExtractor factory. """ extractor_name = "NwbSorting" @@ -957,6 +1055,93 @@ def __init__( } +class NwbSortingExtractor(BaseSorting): + """Load an NWBFile as a SortingExtractor. + Parameters + ---------- + file_path: str or Path + Path to NWB file. + electrical_series_name: str or None, default: None + The name of the ElectricalSeries (if multiple ElectricalSeries are present). 
+ sampling_frequency: float or None, default: None + The sampling frequency in Hz (required if no ElectricalSeries is available). + samples_for_rate_estimation: int, default: 1000 + The number of timestamp samples to use to estimate the rate. + Used if "rate" is not specified in the ElectricalSeries. + stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None + The streaming mode to use. If None it assumes the file is on the local disk. + cache: bool, default: False + If True, the file is cached in the file passed to stream_cache_path + if False, the file is not cached. + stream_cache_path: str or Path or None, default: None + Local path for caching. If None it uses the system temporary directory. + t_start: float or None, default: None + This is the time at which the corresponding ElectricalSeries starts. NWB stores its spikes as times + and the `t_start` is used to convert the times to frames. Concretely, the returned frames are computed as: + + `frames = (times - t_start) * sampling_frequency`. + + As SpikeInterface always considers the first frame to be at the beginning of the recording independently + of the `t_start`. + + When a `t_start` is not provided it will be inferred from the corresponding ElectricalSeries with name equal + to `electrical_series_name`. The `t_start` then will be either the `ElectricalSeries.starting_time` or the + first timestamp in the `ElectricalSeries.timestamps`. + + use_pynwb: bool, default: False + Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py + to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations. + Returns + ------- + sorting: NwbSortingExtractor + The sorting extractor for the NWB file. 
+ """ + + extractor_name = "NwbSorting" + mode = "file" + installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" + name = "nwb" + + def __new__( + self, + file_path: str | Path, + electrical_series_name: str | None = None, + sampling_frequency: float | None = None, + samples_for_rate_estimation: int = 1000, + stream_mode: str | None = None, + stream_cache_path: str | Path | None = None, + *, + t_start: float | None = None, + cache: bool = False, + use_pynwb: bool = False, + ): + if use_pynwb: + extractor = _NwbPynwbSortingExtractor( + file_path=file_path, + electrical_series_name=electrical_series_name, + sampling_frequency=sampling_frequency, + samples_for_rate_estimation=samples_for_rate_estimation, + stream_mode=stream_mode, + stream_cache_path=stream_cache_path, + cache=cache, + t_start=t_start, + ) + + else: + extractor = _NwbHDF5SortingExtractor( + file_path=file_path, + electrical_series_name=electrical_series_name, + sampling_frequency=sampling_frequency, + samples_for_rate_estimation=samples_for_rate_estimation, + stream_mode=stream_mode, + stream_cache_path=stream_cache_path, + cache=cache, + t_start=t_start, + ) + + return extractor + + class NwbSortingSegment(BaseSortingSegment): def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float): BaseSortingSegment.__init__(self) diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index 4a076fac7f..a3481f8d5d 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -226,7 +226,8 @@ def test_that_hdf5_and_pynwb_extractors_return_the_same_data(path_to_nwbfile, el check_recordings_equal(recording_extractor_hdf5, recording_extractor_pynwb) -def test_sorting_extraction_of_ragged_arrays(tmp_path): +@pytest.mark.parametrize("use_pynwb", [True, False]) +def 
test_sorting_extraction_of_ragged_arrays(tmp_path, use_pynwb): nwbfile = mock_NWBFile() # Add the spikes @@ -267,7 +268,12 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path): with NWBHDF5IO(path=file_path, mode="w") as io: io.write(nwbfile) - sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0, t_start=0) + sorting_extractor = NwbSortingExtractor( + file_path=file_path, + sampling_frequency=10.0, + t_start=0, + use_pynwb=use_pynwb, + ) units_ids = sorting_extractor.get_unit_ids() @@ -286,7 +292,8 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path): np.testing.assert_allclose(extracted_spike_times_b, spike_times_b) -def test_sorting_extraction_start_time(tmp_path): +@pytest.mark.parametrize("use_pynwb", [True, False]) +def test_sorting_extraction_start_time(tmp_path, use_pynwb): nwbfile = mock_NWBFile() # Add the spikes @@ -303,7 +310,12 @@ def test_sorting_extraction_start_time(tmp_path): with NWBHDF5IO(path=file_path, mode="w") as io: io.write(nwbfile) - sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=sampling_frequency, t_start=t_start) + sorting_extractor = NwbSortingExtractor( + file_path=file_path, + sampling_frequency=sampling_frequency, + t_start=t_start, + use_pynwb=use_pynwb, + ) # Test frames extracted_frames0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=False) @@ -324,7 +336,8 @@ def test_sorting_extraction_start_time(tmp_path): np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) -def test_sorting_extraction_start_time_from_series(tmp_path): +@pytest.mark.parametrize("use_pynwb", [True, False]) +def test_sorting_extraction_start_time_from_series(tmp_path, use_pynwb): nwbfile = mock_NWBFile() electrical_series_name = "ElectricalSeries" t_start = 10.0 @@ -350,7 +363,11 @@ def test_sorting_extraction_start_time_from_series(tmp_path): with NWBHDF5IO(path=file_path, mode="w") as io: io.write(nwbfile) - sorting_extractor = 
NwbSortingExtractor(file_path=file_path, electrical_series_name=electrical_series_name) + sorting_extractor = NwbSortingExtractor( + file_path=file_path, + electrical_series_name=electrical_series_name, + use_pynwb=use_pynwb, + ) # Test frames extracted_frames0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=False)