Skip to content

Commit

Permalink
Add hdf5 backend support for NwbSortingExtractor II (#2341)
Browse files Browse the repository at this point in the history
* fast recorder

* factory pattern

* docstring

* change fast_mode to bakcned and improve docstring
  • Loading branch information
h-mayorquin authored Dec 15, 2023
1 parent 4f46965 commit 00bcd54
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 42 deletions.
257 changes: 221 additions & 36 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __init__(
file_path: str | Path | None = None, # provide either this or file
electrical_series_name: str | None = None,
load_time_vector: bool = False,
samples_for_rate_estimation: int = 1000,
samples_for_rate_estimation: int = 1_000,
stream_mode: Optional[Literal["fsspec", "ros3", "remfile"]] = None,
stream_cache_path: str | Path | None = None,
*,
Expand Down Expand Up @@ -399,7 +399,7 @@ def __init__(
class _NWBHDF5RecordingExtractor(BaseRecording):
"""
A RecordingExtractor for NWB files. This uses the hdf5 API to extract the traces and
the metadata and is called by the NwbRecordingExtractor factory. This is faster
the metadata and is called by the NwbRecordingExtractor factory. This should be faster
as it avoids the pynwb validation overhead.
"""

Expand All @@ -408,7 +408,7 @@ def __init__(
file_path: str | Path | None = None, # provide either this or file
electrical_series_name: str | None = None,
load_time_vector: bool = False,
samples_for_rate_estimation: int = 10_0000,
samples_for_rate_estimation: int = 1_000,
stream_mode: Optional[Literal["fsspec", "ros3", "remfile"]] = None,
stream_cache_path: str | Path | None = None,
*,
Expand Down Expand Up @@ -798,45 +798,143 @@ def get_traces(self, start_frame, end_frame, channel_indices):
return traces


class NwbSortingExtractor(BaseSorting):
"""Load an NWBFile as a SortingExtractor.
class _NwbHDF5SortingExtractor(BaseSorting):
def __init__(
self,
*,
file_path: str | Path,
electrical_series_name: str | None = None,
sampling_frequency: float | None = None,
samples_for_rate_estimation: int = 1_000,
stream_mode: str | None = None,
stream_cache_path: str | Path | None = None,
cache: bool = False,
t_start: float | None = None,
):
self.stream_mode = stream_mode
self.stream_cache_path = stream_cache_path

Parameters
----------
file_path: str or Path
Path to NWB file.
electrical_series_name: str or None, default: None
The name of the ElectricalSeries (if multiple ElectricalSeries are present).
sampling_frequency: float or None, default: None
The sampling frequency in Hz (required if no ElectricalSeries is available).
samples_for_rate_estimation: int, default: 1000
The number of timestamp samples to use to estimate the rate.
Used if "rate" is not specified in the ElectricalSeries.
stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None
The streaming mode to use. If None it assumes the file is on the local disk.
cache: bool, default: False
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
stream_cache_path: str or Path or None, default: None
Local path for caching. If None it uses the system temporary directory.
t_start: float or None, default: None
This is the time at which the corresponding ElectricalSeries start. NWB stores its spikes as times
and the `t_start` is used to convert the times to seconds. Concrently, the returned frames are computed as:
hdf5_file = _read_hdf5_file(
file_path=file_path,
stream_mode=stream_mode,
cache=cache,
stream_cache_path=stream_cache_path,
)

`frames = (times - t_start) * sampling_frequency`.
timestamps = None
self.t_start = t_start

As SpikeInterface always considers the first frame to be at the beginning of the recording independently
of the `t_start`.
if sampling_frequency is None or t_start is None:
# defines the electrical series from where the sorting came from
# important to know the sampling_frequency
available_electrical_series = _NWBHDF5RecordingExtractor.find_electrical_series(hdf5_file)
if electrical_series_name is None:
if len(available_electrical_series) == 1:
electrical_series_name = list(available_electrical_series.keys())[0]
else:
raise ValueError(
"Multiple ElectricalSeries found in the file. "
"Please specify the 'electrical_series_name' argument:"
f"Available options are: {available_electrical_series}."
)
else:
if electrical_series_name not in available_electrical_series:
raise ValueError(
f"'{electrical_series_name}' not found in the file. "
f"Available options are: {available_electrical_series}"
)
self.electrical_series_location = available_electrical_series[electrical_series_name]
electrical_series = hdf5_file[self.electrical_series_location]

When a `t_start` is not provided it will be inferred from the corresponding ElectricalSeries with name equal
to `electrical_series_name`. The `t_start` then will be either the `ElectricalSeries.starting_time` or the
first timestamp in the `ElectricalSeries.timestamps`.
# Get sampling frequency
if "starting_time" in electrical_series.keys():
t_start = electrical_series["starting_time"][()]
sampling_frequency = electrical_series["starting_time"].attrs["rate"]
elif "timestamps" in electrical_series.keys():
timestamps = electrical_series["timestamps"][:]
t_start = timestamps[0]
sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation]))

self.t_start = t_start

Returns
-------
sorting: NwbSortingExtractor
The sorting extractor for the NWB file.
assert (
sampling_frequency is not None
), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument"
assert (
self.t_start is not None
), "Couldn't load a starting time for the sorting. Please provide it with the 't_start' argument"

units_table = hdf5_file["units"]
spike_times_data = units_table["spike_times"]
spike_times_index_data = units_table["spike_times_index"]

if "unit_name" in units_table:
unit_ids = units_table["unit_name"]
else:
unit_ids = units_table["id"]

decode_to_string = lambda x: x.decode("utf-8") if isinstance(x, bytes) else x
unit_ids = [decode_to_string(id) for id in unit_ids]
BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids)

sorting_segment = NwbSortingSegment(
spike_times_data=spike_times_data,
spike_times_index_data=spike_times_index_data,
sampling_frequency=sampling_frequency,
t_start=self.t_start,
)
self.add_sorting_segment(sorting_segment)

# Skip canonical properties and indices
caonical_properties = ["spike_times", "spike_times_index", "unit_name"]
index_properties = [name for name in units_table if name.endswith("_index")]
nested_ragged_array_properties = [name for name in units_table if f"{name}_index_index" in units_table]

skip_properties = caonical_properties + index_properties + nested_ragged_array_properties
properties_to_add = [name for name in units_table if name not in skip_properties]

for property_name in properties_to_add:
data = units_table[property_name]
corresponding_index_name = f"{property_name}_index"
not_ragged_array = corresponding_index_name not in units_table
if not_ragged_array:
values = data[:]
else:
data_index = units_table[corresponding_index_name][:]
index_spacing = np.diff(data_index, prepend=0)
all_index_spacing_are_the_same = np.unique(index_spacing).size == 1
if all_index_spacing_are_the_same:
start_indices = [0] + list(data_index[:-1])
end_indices = list(data_index)
values = [data[start_index:end_index] for start_index, end_index in zip(start_indices, end_indices)]

else:
warnings.warn(f"Skipping {property_name} because of unequal shapes across units")
continue

decode_to_string = lambda x: x.decode("utf-8") if isinstance(x, bytes) else x
values = [decode_to_string(val) for val in values]
self.set_property(property_name, np.asarray(values))

if stream_mode is None and file_path is not None:
file_path = str(Path(file_path).resolve())

self._kwargs = {
"file_path": file_path,
"electrical_series_name": electrical_series_name,
"sampling_frequency": sampling_frequency,
"samples_for_rate_estimation": samples_for_rate_estimation,
"cache": cache,
"stream_mode": stream_mode,
"stream_cache_path": stream_cache_path,
"t_start": self.t_start,
}


class _NwbPynwbSortingExtractor(BaseSorting):
"""
A SortingExtractor for NWB files. This uses the NWB API to extract the traces and
the metadata and is called by the NwbSortingExtractor factory.
"""

extractor_name = "NwbSorting"
Expand Down Expand Up @@ -957,6 +1055,93 @@ def __init__(
}


class NwbSortingExtractor(BaseSorting):
"""Load an NWBFile as a SortingExtractor.
Parameters
----------
file_path: str or Path
Path to NWB file.
electrical_series_name: str or None, default: None
The name of the ElectricalSeries (if multiple ElectricalSeries are present).
sampling_frequency: float or None, default: None
The sampling frequency in Hz (required if no ElectricalSeries is available).
samples_for_rate_estimation: int, default: 100000
The number of timestamp samples to use to estimate the rate.
Used if "rate" is not specified in the ElectricalSeries.
stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None
The streaming mode to use. If None it assumes the file is on the local disk.
cache: bool, default: False
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
stream_cache_path: str or Path or None, default: None
Local path for caching. If None it uses the system temporary directory.
t_start: float or None, default: None
This is the time at which the corresponding ElectricalSeries start. NWB stores its spikes as times
and the `t_start` is used to convert the times to seconds. Concrently, the returned frames are computed as:
`frames = (times - t_start) * sampling_frequency`.
As SpikeInterface always considers the first frame to be at the beginning of the recording independently
of the `t_start`.
When a `t_start` is not provided it will be inferred from the corresponding ElectricalSeries with name equal
to `electrical_series_name`. The `t_start` then will be either the `ElectricalSeries.starting_time` or the
first timestamp in the `ElectricalSeries.timestamps`.
use_pynwb: bool, default: False
Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py
to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations.
Returns
-------
sorting: NwbSortingExtractor
The sorting extractor for the NWB file.
"""

extractor_name = "NwbSorting"
mode = "file"
installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n"
name = "nwb"

def __new__(
self,
file_path: str | Path,
electrical_series_name: str | None = None,
sampling_frequency: float | None = None,
samples_for_rate_estimation: int = 1000,
stream_mode: str | None = None,
stream_cache_path: str | Path | None = None,
*,
t_start: float | None = None,
cache: bool = False,
use_pynwb: bool = False,
):
if use_pynwb:
extractor = _NwbPynwbSortingExtractor(
file_path=file_path,
electrical_series_name=electrical_series_name,
sampling_frequency=sampling_frequency,
samples_for_rate_estimation=samples_for_rate_estimation,
stream_mode=stream_mode,
stream_cache_path=stream_cache_path,
cache=cache,
t_start=t_start,
)

else:
extractor = _NwbHDF5SortingExtractor(
file_path=file_path,
electrical_series_name=electrical_series_name,
sampling_frequency=sampling_frequency,
samples_for_rate_estimation=samples_for_rate_estimation,
stream_mode=stream_mode,
stream_cache_path=stream_cache_path,
cache=cache,
t_start=t_start,
)

return extractor


class NwbSortingSegment(BaseSortingSegment):
def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float):
BaseSortingSegment.__init__(self)
Expand Down
29 changes: 23 additions & 6 deletions src/spikeinterface/extractors/tests/test_nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def test_that_hdf5_and_pynwb_extractors_return_the_same_data(path_to_nwbfile, el
check_recordings_equal(recording_extractor_hdf5, recording_extractor_pynwb)


def test_sorting_extraction_of_ragged_arrays(tmp_path):
@pytest.mark.parametrize("use_pynwb", [True, False])
def test_sorting_extraction_of_ragged_arrays(tmp_path, use_pynwb):
nwbfile = mock_NWBFile()

# Add the spikes
Expand Down Expand Up @@ -267,7 +268,12 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path):
with NWBHDF5IO(path=file_path, mode="w") as io:
io.write(nwbfile)

sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0, t_start=0)
sorting_extractor = NwbSortingExtractor(
file_path=file_path,
sampling_frequency=10.0,
t_start=0,
use_pynwb=use_pynwb,
)

units_ids = sorting_extractor.get_unit_ids()

Expand All @@ -286,7 +292,8 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path):
np.testing.assert_allclose(extracted_spike_times_b, spike_times_b)


def test_sorting_extraction_start_time(tmp_path):
@pytest.mark.parametrize("use_pynwb", [True, False])
def test_sorting_extraction_start_time(tmp_path, use_pynwb):
nwbfile = mock_NWBFile()

# Add the spikes
Expand All @@ -303,7 +310,12 @@ def test_sorting_extraction_start_time(tmp_path):
with NWBHDF5IO(path=file_path, mode="w") as io:
io.write(nwbfile)

sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=sampling_frequency, t_start=t_start)
sorting_extractor = NwbSortingExtractor(
file_path=file_path,
sampling_frequency=sampling_frequency,
t_start=t_start,
use_pynwb=use_pynwb,
)

# Test frames
extracted_frames0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=False)
Expand All @@ -324,7 +336,8 @@ def test_sorting_extraction_start_time(tmp_path):
np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1)


def test_sorting_extraction_start_time_from_series(tmp_path):
@pytest.mark.parametrize("use_pynwb", [True, False])
def test_sorting_extraction_start_time_from_series(tmp_path, use_pynwb):
nwbfile = mock_NWBFile()
electrical_series_name = "ElectricalSeries"
t_start = 10.0
Expand All @@ -350,7 +363,11 @@ def test_sorting_extraction_start_time_from_series(tmp_path):
with NWBHDF5IO(path=file_path, mode="w") as io:
io.write(nwbfile)

sorting_extractor = NwbSortingExtractor(file_path=file_path, electrical_series_name=electrical_series_name)
sorting_extractor = NwbSortingExtractor(
file_path=file_path,
electrical_series_name=electrical_series_name,
use_pynwb=use_pynwb,
)

# Test frames
extracted_frames0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=False)
Expand Down

0 comments on commit 00bcd54

Please sign in to comment.