Skip to content

Commit

Permalink
Merge pull request #2169 from magland/nwb-remfile
Browse files (browse the repository at this point in the history)
Support remfile and file-like in nwb rec extractor
  • Loading branch information
alejoe91 authored Nov 27, 2023
2 parents 0da79b6 + 273766b commit ee88ef3
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 17 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ streaming_extractors = [
"aiohttp",
"requests",
"pynwb>=2.3.0",
"remfile"
]

full = [
Expand Down
79 changes: 62 additions & 17 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from pathlib import Path
from typing import Union, List, Optional, Literal, Dict
from typing import Union, List, Optional, Literal, Dict, BinaryIO

import numpy as np

Expand Down Expand Up @@ -69,8 +69,10 @@ def retrieve_electrical_series(nwbfile: NWBFile, electrical_series_name: Optiona


def read_nwbfile(
file_path: str | Path,
stream_mode: Literal["ffspec", "ros3"] | None = None,
*,
file_path: str | Path | None,
file: BinaryIO | None = None,
stream_mode: Literal["ffspec", "ros3", "remfile"] | None = None,
cache: bool = True,
stream_cache_path: str | Path | None = None,
) -> NWBFile:
Expand All @@ -79,9 +81,11 @@ def read_nwbfile(
Parameters
----------
file_path : Path, str
The path to the NWB file.
stream_mode : "fsspec" or "ros3" or None, default: None
file_path : Path, str or None
The path to the NWB file. Either provide this or file.
file : file-like object or None
The file-like object to read from. Either provide this or file_path.
stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None
The streaming mode to use. If None it assumes the file is on the local disk.
cache: bool, default: True
If True, the file is cached in the file passed to stream_cache_path
Expand Down Expand Up @@ -110,12 +114,19 @@ def read_nwbfile(
"""
from pynwb import NWBHDF5IO

if file_path is not None and file is not None:
raise ValueError("Provide either file_path or file, not both")
if file_path is None and file is None:
raise ValueError("Provide either file_path or file")

if stream_mode == "fsspec":
import fsspec
import h5py

from fsspec.implementations.cached import CachingFileSystem

assert file_path is not None, "file_path must be specified when using stream_mode='fsspec'"

fsspec_file_system = fsspec.filesystem("http")

if cache:
Expand All @@ -134,15 +145,33 @@ def read_nwbfile(
elif stream_mode == "ros3":
import h5py

assert file_path is not None, "file_path must be specified when using stream_mode='ros3'"

drivers = h5py.registered_drivers()
assertion_msg = "ROS3 support not enbabled, use: install -c conda-forge h5py>=3.2 to enable streaming"
assert "ros3" in drivers, assertion_msg
io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True, driver="ros3")

else:
elif stream_mode == "remfile":
import remfile
import h5py

assert file_path is not None, "file_path must be specified when using stream_mode='remfile'"
rfile = remfile.File(file_path)
h5_file = h5py.File(rfile, "r")
io = NWBHDF5IO(file=h5_file, mode="r", load_namespaces=True)

elif file_path is not None:
file_path = str(Path(file_path).absolute())
io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True)

else:
import h5py

assert file is not None, "Unexpected, file is None"
h5_file = h5py.File(file, "r")
io = NWBHDF5IO(file=h5_file, mode="r", load_namespaces=True)

nwbfile = io.read()
return nwbfile

Expand All @@ -152,10 +181,12 @@ class NwbRecordingExtractor(BaseRecording):
Parameters
----------
file_path: str or Path
Path to NWB file or s3 url.
file_path: str, Path, or None
Path to NWB file or s3 url (or None if using file instead)
electrical_series_name: str or None, default: None
The name of the ElectricalSeries. Used if multiple ElectricalSeries are present.
file: file-like object or None, default: None
File-like object to read from (if None, file_path must be specified)
load_time_vector: bool, default: False
If True, the time vector is loaded to the recording object.
samples_for_rate_estimation: int, default: 100000
Expand All @@ -167,7 +198,7 @@ class NwbRecordingExtractor(BaseRecording):
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
stream_cache_path: str or Path or None, default: None
Local path for caching. If None it uses cwd
Local path for caching. If None it uses the current working directory (cwd)
Returns
-------
Expand Down Expand Up @@ -202,27 +233,34 @@ class NwbRecordingExtractor(BaseRecording):

def __init__(
self,
file_path: str | Path,
electrical_series_name: str = None,
file_path: str | Path | None = None, # provide either this or file
electrical_series_name: str | None = None,
load_time_vector: bool = False,
samples_for_rate_estimation: int = 100000,
cache: bool = True,
stream_mode: Optional[Literal["fsspec", "ros3"]] = None,
stream_mode: Optional[Literal["fsspec", "ros3", "remfile"]] = None,
stream_cache_path: str | Path | None = None,
*,
file: BinaryIO | None = None, # file-like - provide either this or file_path
):
try:
from pynwb import NWBHDF5IO, NWBFile
from pynwb.ecephys import ElectrodeGroup
except ImportError:
raise ImportError(self.installation_mesg)

if file_path is not None and file is not None:
raise ValueError("Provide either file_path or file, not both")
if file_path is None and file is None:
raise ValueError("Provide either file_path or file")

self.stream_mode = stream_mode
self.stream_cache_path = stream_cache_path
self._electrical_series_name = electrical_series_name

self.file_path = file_path
self._nwbfile = read_nwbfile(
file_path=file_path, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path
file_path=file_path, file=file, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path
)
electrical_series = retrieve_electrical_series(self._nwbfile, electrical_series_name)
# The indices in the electrode table corresponding to this electrical series
Expand Down Expand Up @@ -374,15 +412,21 @@ def __init__(
else:
self.set_property(property_name, values)

if stream_mode not in ["fsspec", "ros3"]:
file_path = str(Path(file_path).absolute())
if stream_mode not in ["fsspec", "ros3", "remfile"]:
if file_path is not None:
file_path = str(Path(file_path).absolute())
if stream_mode == "fsspec":
# only add stream_cache_path to kwargs if it was passed as an argument
if stream_cache_path is not None:
stream_cache_path = str(Path(self.stream_cache_path).absolute())

self.extra_requirements.extend(["pandas", "pynwb", "hdmf"])
self._electrical_series = electrical_series

# set serializability bools
if file is not None:
# not json serializable if file arg is provided
self._serializability["json"] = False

self._kwargs = {
"file_path": file_path,
"electrical_series_name": self._electrical_series_name,
Expand All @@ -391,6 +435,7 @@ def __init__(
"stream_mode": stream_mode,
"cache": cache,
"stream_cache_path": stream_cache_path,
"file": file,
}


Expand Down
65 changes: 65 additions & 0 deletions src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import numpy as np
import h5py
from spikeinterface.core.testing import check_recordings_equal

from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal
from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor
Expand Down Expand Up @@ -87,6 +88,70 @@ def test_recording_s3_nwb_fsspec(tmp_path, cache):
check_recordings_equal(rec, reloaded_recording)


@pytest.mark.streaming_extractors
def test_recording_s3_nwb_remfile():
    """Open a remote NWB file with ``stream_mode="remfile"`` and check a short trace read.

    For every segment, frames 0-300 must come back with shape
    ``(300, num_channels)`` and the recording's own dtype. When the recording
    reports scaling support, a scaled read must be ``float32``.
    """
    file_path = (
        "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
    )
    rec = NwbRecordingExtractor(file_path, stream_mode="remfile")

    first, last = 0, 300
    n_frames = last - first

    n_segments = rec.get_num_segments()
    n_channels = rec.get_num_channels()
    expected_dtype = rec.get_dtype()
    expected_shape = (n_frames, n_channels)

    for seg in range(n_segments):
        # Called only to exercise the accessor on the streamed file; the value is not asserted.
        rec.get_num_samples(segment_index=seg)

        traces = rec.get_traces(segment_index=seg, start_frame=first, end_frame=last)
        assert traces.shape == expected_shape
        assert traces.dtype == expected_dtype

        if not rec.has_scaled():
            continue
        scaled = rec.get_traces(segment_index=seg, return_scaled=True, end_frame=2)
        assert scaled.dtype == "float32"


@pytest.mark.streaming_extractors
def test_recording_s3_nwb_remfile_file_like(tmp_path):
    """Build the extractor from a ``remfile.File`` object and check reads and pickling.

    The same per-segment shape/dtype checks as the URL-based remfile test are
    run, but the recording is created through the ``file=`` keyword instead of
    ``file_path``. Afterwards the recording is pickled to ``tmp_path`` and the
    reloaded copy is compared with the original.
    """
    import remfile

    file_path = (
        "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
    )
    remote_file = remfile.File(file_path)
    rec = NwbRecordingExtractor(file=remote_file)

    first, last = 0, 300
    n_frames = last - first

    n_segments = rec.get_num_segments()
    n_channels = rec.get_num_channels()
    expected_dtype = rec.get_dtype()
    expected_shape = (n_frames, n_channels)

    for seg in range(n_segments):
        # Called only to exercise the accessor on the streamed file; the value is not asserted.
        rec.get_num_samples(segment_index=seg)

        traces = rec.get_traces(segment_index=seg, start_frame=first, end_frame=last)
        assert traces.shape == expected_shape
        assert traces.dtype == expected_dtype

        if not rec.has_scaled():
            continue
        scaled = rec.get_traces(segment_index=seg, return_scaled=True, end_frame=2)
        assert scaled.dtype == "float32"

    # Pickle round-trip: dump the recording, load it back, and compare the two.
    with open(tmp_path / "rec.pkl", "wb") as sink:
        pickle.dump(rec, sink)
    with open(tmp_path / "rec.pkl", "rb") as source:
        reloaded = pickle.load(source)
    check_recordings_equal(rec, reloaded)


@pytest.mark.ros3_test
@pytest.mark.streaming_extractors
@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
Expand Down

0 comments on commit ee88ef3

Please sign in to comment.