Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support remfile and file-like in nwb rec extractor #2169

Merged
merged 16 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ streaming_extractors = [
"aiohttp",
"requests",
"pynwb>=2.3.0",
"remfile"
]

full = [
Expand Down
79 changes: 62 additions & 17 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from pathlib import Path
from typing import Union, List, Optional, Literal, Dict
from typing import Union, List, Optional, Literal, Dict, BinaryIO

import numpy as np

Expand Down Expand Up @@ -69,8 +69,10 @@ def retrieve_electrical_series(nwbfile: NWBFile, electrical_series_name: Optiona


def read_nwbfile(
file_path: str | Path,
stream_mode: Literal["ffspec", "ros3"] | None = None,
*,
file_path: str | Path | None,
file: BinaryIO | None = None,
stream_mode: Literal["ffspec", "ros3", "remfile"] | None = None,
cache: bool = True,
stream_cache_path: str | Path | None = None,
) -> NWBFile:
Expand All @@ -79,9 +81,11 @@ def read_nwbfile(

Parameters
----------
file_path : Path, str
The path to the NWB file.
stream_mode : "fsspec" or "ros3" or None, default: None
file_path : Path, str or None
The path to the NWB file. Either provide this or file.
file : file-like object or None
The file-like object to read from. Either provide this or file_path.
stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None
The streaming mode to use. If None it assumes the file is on the local disk.
cache: bool, default: True
If True, the file is cached in the file passed to stream_cache_path
Expand Down Expand Up @@ -110,12 +114,19 @@ def read_nwbfile(
"""
from pynwb import NWBHDF5IO

if file_path is not None and file is not None:
raise ValueError("Provide either file_path or file, not both")
if file_path is None and file is None:
raise ValueError("Provide either file_path or file")

if stream_mode == "fsspec":
import fsspec
import h5py

from fsspec.implementations.cached import CachingFileSystem

assert file_path is not None, "file_path must be specified when using stream_mode='fsspec'"

fsspec_file_system = fsspec.filesystem("http")

if cache:
Expand All @@ -134,15 +145,33 @@ def read_nwbfile(
elif stream_mode == "ros3":
import h5py

assert file_path is not None, "file_path must be specified when using stream_mode='ros3'"

drivers = h5py.registered_drivers()
assertion_msg = "ROS3 support not enbabled, use: install -c conda-forge h5py>=3.2 to enable streaming"
assert "ros3" in drivers, assertion_msg
io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True, driver="ros3")

else:
elif stream_mode == "remfile":
import remfile
import h5py

assert file_path is not None, "file_path must be specified when using stream_mode='remfile'"
rfile = remfile.File(file_path)
h5_file = h5py.File(rfile, "r")
io = NWBHDF5IO(file=h5_file, mode="r", load_namespaces=True)

elif file_path is not None:
file_path = str(Path(file_path).absolute())
io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True)

else:
import h5py

assert file is not None, "Unexpected, file is None"
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
h5_file = h5py.File(file, "r")
io = NWBHDF5IO(file=h5_file, mode="r", load_namespaces=True)

nwbfile = io.read()
return nwbfile

Expand All @@ -152,10 +181,12 @@ class NwbRecordingExtractor(BaseRecording):

Parameters
----------
file_path: str or Path
Path to NWB file or s3 url.
file_path: str, Path, or None
Path to NWB file or s3 url (or None if using file instead)
electrical_series_name: str or None, default: None
The name of the ElectricalSeries. Used if multiple ElectricalSeries are present.
file: file-like object or None, default: None
File-like object to read from (if None, file_path must be specified)
load_time_vector: bool, default: False
If True, the time vector is loaded to the recording object.
samples_for_rate_estimation: int, default: 100000
Expand All @@ -167,7 +198,7 @@ class NwbRecordingExtractor(BaseRecording):
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
stream_cache_path: str or Path or None, default: None
Local path for caching. If None it uses cwd
Local path for caching. If None it uses the current working directory (cwd)

Returns
-------
Expand Down Expand Up @@ -202,27 +233,34 @@ class NwbRecordingExtractor(BaseRecording):

def __init__(
self,
file_path: str | Path,
electrical_series_name: str = None,
file_path: str | Path | None = None, # provide either this or file
electrical_series_name: str | None = None,
load_time_vector: bool = False,
samples_for_rate_estimation: int = 100000,
cache: bool = True,
stream_mode: Optional[Literal["fsspec", "ros3"]] = None,
stream_mode: Optional[Literal["fsspec", "ros3", "remfile"]] = None,
h-mayorquin marked this conversation as resolved.
Show resolved Hide resolved
stream_cache_path: str | Path | None = None,
*,
file: BinaryIO | None = None, # file-like - provide either this or file_path
):
try:
from pynwb import NWBHDF5IO, NWBFile
from pynwb.ecephys import ElectrodeGroup
except ImportError:
raise ImportError(self.installation_mesg)

if file_path is not None and file is not None:
raise ValueError("Provide either file_path or file, not both")
if file_path is None and file is None:
raise ValueError("Provide either file_path or file")

self.stream_mode = stream_mode
self.stream_cache_path = stream_cache_path
self._electrical_series_name = electrical_series_name

self.file_path = file_path
self._nwbfile = read_nwbfile(
file_path=file_path, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path
file_path=file_path, file=file, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path
)
electrical_series = retrieve_electrical_series(self._nwbfile, electrical_series_name)
# The indices in the electrode table corresponding to this electrical series
Expand Down Expand Up @@ -374,15 +412,21 @@ def __init__(
else:
self.set_property(property_name, values)

if stream_mode not in ["fsspec", "ros3"]:
file_path = str(Path(file_path).absolute())
if stream_mode not in ["fsspec", "ros3", "remfile"]:
if file_path is not None:
file_path = str(Path(file_path).absolute())
if stream_mode == "fsspec":
# only add stream_cache_path to kwargs if it was passed as an argument
if stream_cache_path is not None:
stream_cache_path = str(Path(self.stream_cache_path).absolute())

self.extra_requirements.extend(["pandas", "pynwb", "hdmf"])
self._electrical_series = electrical_series

# set serializability bools
if file is not None:
# not json serializable if file arg is provided
self._serializability["json"] = False

self._kwargs = {
"file_path": file_path,
"electrical_series_name": self._electrical_series_name,
Expand All @@ -391,6 +435,7 @@ def __init__(
"stream_mode": stream_mode,
"cache": cache,
"stream_cache_path": stream_cache_path,
"file": file,
}


Expand Down
65 changes: 65 additions & 0 deletions src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import numpy as np
import h5py
from spikeinterface.core.testing import check_recordings_equal

from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal
from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor
Expand Down Expand Up @@ -87,6 +88,70 @@ def test_recording_s3_nwb_fsspec(tmp_path, cache):
check_recordings_equal(rec, reloaded_recording)


@pytest.mark.streaming_extractors
def test_recording_s3_nwb_remfile():
    """Open a remote NWB file with ``stream_mode="remfile"`` and check the traces it returns.

    For every segment, a 300-frame window is read and its shape and dtype are
    compared against what the recording reports. When the recording has scaled
    traces, a short scaled read must come back as float32.
    """
    url = (
        "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
    )
    rec = NwbRecordingExtractor(url, stream_mode="remfile")

    first_frame, last_frame = 0, 300
    expected_num_frames = last_frame - first_frame

    segment_count = rec.get_num_segments()
    channel_count = rec.get_num_channels()
    expected_dtype = rec.get_dtype()

    for seg in range(segment_count):
        # The returned value is not used; the call only exercises the accessor.
        rec.get_num_samples(segment_index=seg)

        window = rec.get_traces(segment_index=seg, start_frame=first_frame, end_frame=last_frame)
        assert window.shape == (expected_num_frames, channel_count)
        assert window.dtype == expected_dtype

        if not rec.has_scaled():
            continue
        scaled_window = rec.get_traces(segment_index=seg, return_scaled=True, end_frame=2)
        assert scaled_window.dtype == "float32"


@pytest.mark.streaming_extractors
def test_recording_s3_nwb_remfile_file_like(tmp_path):
    """Build the extractor from a ``remfile.File`` object passed as ``file=`` and check it.

    The same per-segment shape/dtype checks are run as for the path-based
    streaming test, followed by a pickle round trip whose result must compare
    equal to the original recording.
    """
    import remfile

    url = (
        "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
    )
    remote_file = remfile.File(url)
    rec = NwbRecordingExtractor(file=remote_file)

    first_frame, last_frame = 0, 300
    expected_num_frames = last_frame - first_frame

    segment_count = rec.get_num_segments()
    channel_count = rec.get_num_channels()
    expected_dtype = rec.get_dtype()

    for seg in range(segment_count):
        # The returned value is not used; the call only exercises the accessor.
        rec.get_num_samples(segment_index=seg)

        window = rec.get_traces(segment_index=seg, start_frame=first_frame, end_frame=last_frame)
        assert window.shape == (expected_num_frames, channel_count)
        assert window.dtype == expected_dtype

        if not rec.has_scaled():
            continue
        scaled_window = rec.get_traces(segment_index=seg, return_scaled=True, end_frame=2)
        assert scaled_window.dtype == "float32"

    # Pickle round trip: dump to a temporary file, load it back, compare.
    pickle_path = tmp_path / "rec.pkl"
    with open(pickle_path, "wb") as sink:
        pickle.dump(rec, sink)
    with open(pickle_path, "rb") as source:
        restored = pickle.load(source)
    check_recordings_equal(rec, restored)


h-mayorquin marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.ros3_test
@pytest.mark.streaming_extractors
@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
Expand Down