Skip to content

Commit

Permalink
Merge pull request #2294 from h-mayorquin/edge_case_nwb
Browse files Browse the repository at this point in the history
Refactor NWB extractor to separate reading hdf5 file from reading nwb
  • Loading branch information
alejoe91 authored Dec 6, 2023
2 parents 13036f1 + 9e8b84f commit 170eefb
Showing 1 changed file with 66 additions and 54 deletions.
120 changes: 66 additions & 54 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,61 @@ def retrieve_electrical_series(nwbfile: NWBFile, electrical_series_name: Optiona
return electrical_series


def _read_hdf5_file(
*,
file_path: str | Path | None,
file: BinaryIO | None = None,
stream_mode: Literal["ffspec", "ros3", "remfile"] | None = None,
cache: bool = False,
stream_cache_path: str | Path | None = None,
):
import h5py

if stream_mode == "fsspec":
import fsspec
from fsspec.implementations.cached import CachingFileSystem

assert file_path is not None, "file_path must be specified when using stream_mode='fsspec'"

fsspec_file_system = fsspec.filesystem("http")

if cache:
stream_cache_path = stream_cache_path if stream_cache_path is not None else str(get_global_tmp_folder())
caching_file_system = CachingFileSystem(
fs=fsspec_file_system,
cache_storage=str(stream_cache_path),
)
ffspec_file = caching_file_system.open(path=file_path, mode="rb")
else:
ffspec_file = fsspec_file_system.open(file_path, "rb")

hdf5_file = h5py.File(name=ffspec_file, mode="r")

elif stream_mode == "ros3":
assert file_path is not None, "file_path must be specified when using stream_mode='ros3'"

drivers = h5py.registered_drivers()
assertion_msg = "ROS3 support not enbabled, use: install -c conda-forge h5py>=3.2 to enable streaming"
assert "ros3" in drivers, assertion_msg
hdf5_file = h5py.File(name=file_path, mode="r", driver="ros3")

elif stream_mode == "remfile":
import remfile

assert file_path is not None, "file_path must be specified when using stream_mode='remfile'"
rfile = remfile.File(file_path)
hdf5_file = h5py.File(rfile, "r")

elif file_path is not None:
file_path = str(Path(file_path).resolve())
hdf5_file = h5py.File(name=file_path, mode="r")
else:
assert file is not None, "Unexpected, file is None"
hdf5_file = h5py.File(file, "r")

return hdf5_file


def read_nwbfile(
*,
file_path: str | Path | None,
Expand Down Expand Up @@ -119,59 +174,14 @@ def read_nwbfile(
if file_path is None and file is None:
raise ValueError("Provide either file_path or file")

if stream_mode == "fsspec":
import fsspec
import h5py

from fsspec.implementations.cached import CachingFileSystem

assert file_path is not None, "file_path must be specified when using stream_mode='fsspec'"

fsspec_file_system = fsspec.filesystem("http")

if cache:
stream_cache_path = stream_cache_path if stream_cache_path is not None else str(get_global_tmp_folder())
caching_file_system = CachingFileSystem(
fs=fsspec_file_system,
cache_storage=str(stream_cache_path),
)
ffspec_file = caching_file_system.open(path=file_path, mode="rb")
else:
ffspec_file = fsspec_file_system.open(file_path, "rb")

file = h5py.File(ffspec_file, "r")
io = NWBHDF5IO(file=file, mode="r", load_namespaces=True)

elif stream_mode == "ros3":
import h5py

assert file_path is not None, "file_path must be specified when using stream_mode='ros3'"

drivers = h5py.registered_drivers()
assertion_msg = "ROS3 support not enbabled, use: install -c conda-forge h5py>=3.2 to enable streaming"
assert "ros3" in drivers, assertion_msg
io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True, driver="ros3")

elif stream_mode == "remfile":
import remfile
import h5py

assert file_path is not None, "file_path must be specified when using stream_mode='remfile'"
rfile = remfile.File(file_path)
h5_file = h5py.File(rfile, "r")
io = NWBHDF5IO(file=h5_file, mode="r", load_namespaces=True)

elif file_path is not None:
file_path = str(Path(file_path).absolute())
io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True)

else:
import h5py

assert file is not None, "Unexpected, file is None"
h5_file = h5py.File(file, "r")
io = NWBHDF5IO(file=h5_file, mode="r", load_namespaces=True)

hdf5_file = _read_hdf5_file(
file_path=file_path,
file=file,
stream_mode=stream_mode,
cache=cache,
stream_cache_path=stream_cache_path,
)
io = NWBHDF5IO(file=hdf5_file, mode="r", load_namespaces=True)
nwbfile = io.read()
return nwbfile

Expand Down Expand Up @@ -464,7 +474,9 @@ def get_traces(self, start_frame, end_frame, channel_indices):
end_frame = self.get_num_samples()

electrical_series_data = self.electrical_series.data
if isinstance(channel_indices, slice):
if electrical_series_data.ndim == 1:
traces = electrical_series_data[start_frame:end_frame][:, np.newaxis]
elif isinstance(channel_indices, slice):
traces = electrical_series_data[start_frame:end_frame, channel_indices]
else:
# channel_indices is np.ndarray
Expand Down

0 comments on commit 170eefb

Please sign in to comment.