diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 24e400bdaf..d5f5ac60a5 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -68,6 +68,61 @@ def retrieve_electrical_series(nwbfile: NWBFile, electrical_series_name: Optiona return electrical_series +def _read_hdf5_file( + *, + file_path: str | Path | None, + file: BinaryIO | None = None, + stream_mode: Literal["ffspec", "ros3", "remfile"] | None = None, + cache: bool = False, + stream_cache_path: str | Path | None = None, +): + import h5py + + if stream_mode == "fsspec": + import fsspec + from fsspec.implementations.cached import CachingFileSystem + + assert file_path is not None, "file_path must be specified when using stream_mode='fsspec'" + + fsspec_file_system = fsspec.filesystem("http") + + if cache: + stream_cache_path = stream_cache_path if stream_cache_path is not None else str(get_global_tmp_folder()) + caching_file_system = CachingFileSystem( + fs=fsspec_file_system, + cache_storage=str(stream_cache_path), + ) + ffspec_file = caching_file_system.open(path=file_path, mode="rb") + else: + ffspec_file = fsspec_file_system.open(file_path, "rb") + + hdf5_file = h5py.File(name=ffspec_file, mode="r") + + elif stream_mode == "ros3": + assert file_path is not None, "file_path must be specified when using stream_mode='ros3'" + + drivers = h5py.registered_drivers() + assertion_msg = "ROS3 support not enbabled, use: install -c conda-forge h5py>=3.2 to enable streaming" + assert "ros3" in drivers, assertion_msg + hdf5_file = h5py.File(name=file_path, mode="r", driver="ros3") + + elif stream_mode == "remfile": + import remfile + + assert file_path is not None, "file_path must be specified when using stream_mode='remfile'" + rfile = remfile.File(file_path) + hdf5_file = h5py.File(rfile, "r") + + elif file_path is not None: + file_path = str(Path(file_path).resolve()) + hdf5_file = h5py.File(name=file_path, mode="r") + else: + assert file is not None, "Unexpected, file is None" + hdf5_file = h5py.File(file, "r") + + return hdf5_file + + def read_nwbfile( *, file_path: str | Path | None, @@ -119,59 +174,14 @@ def read_nwbfile( if file_path is None and file is None: raise ValueError("Provide either file_path or file") - if stream_mode == "fsspec": - import fsspec - import h5py - - from fsspec.implementations.cached import CachingFileSystem - - assert file_path is not None, "file_path must be specified when using stream_mode='fsspec'" - - fsspec_file_system = fsspec.filesystem("http") - - if cache: - stream_cache_path = stream_cache_path if stream_cache_path is not None else str(get_global_tmp_folder()) - caching_file_system = CachingFileSystem( - fs=fsspec_file_system, - cache_storage=str(stream_cache_path), - ) - ffspec_file = caching_file_system.open(path=file_path, mode="rb") - else: - ffspec_file = fsspec_file_system.open(file_path, "rb") - - file = h5py.File(ffspec_file, "r") - io = NWBHDF5IO(file=file, mode="r", load_namespaces=True) - - elif stream_mode == "ros3": - import h5py - - assert file_path is not None, "file_path must be specified when using stream_mode='ros3'" - - drivers = h5py.registered_drivers() - assertion_msg = "ROS3 support not enbabled, use: install -c conda-forge h5py>=3.2 to enable streaming" - assert "ros3" in drivers, assertion_msg - io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True, driver="ros3") - - elif stream_mode == "remfile": - import remfile - import h5py - - assert file_path is not None, "file_path must be specified when using stream_mode='remfile'" - rfile = remfile.File(file_path) - h5_file = h5py.File(rfile, "r") - io = NWBHDF5IO(file=h5_file, mode="r", load_namespaces=True) - - elif file_path is not None: - file_path = str(Path(file_path).absolute()) - io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True) - - else: - import h5py - - assert file is not None, "Unexpected, file is None" - h5_file = h5py.File(file, "r") - io = NWBHDF5IO(file=h5_file, mode="r", load_namespaces=True) - + hdf5_file = _read_hdf5_file( + file_path=file_path, + file=file, + stream_mode=stream_mode, + cache=cache, + stream_cache_path=stream_cache_path, + ) + io = NWBHDF5IO(file=hdf5_file, mode="r", load_namespaces=True) nwbfile = io.read() return nwbfile @@ -464,7 +474,9 @@ def get_traces(self, start_frame, end_frame, channel_indices): end_frame = self.get_num_samples() electrical_series_data = self.electrical_series.data - if isinstance(channel_indices, slice): + if electrical_series_data.ndim == 1: + traces = electrical_series_data[start_frame:end_frame][:, np.newaxis] + elif isinstance(channel_indices, slice): traces = electrical_series_data[start_frame:end_frame, channel_indices] else: # channel_indices is np.ndarray