diff --git a/pyproject.toml b/pyproject.toml index f864c215e6..e105bd6a85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ streaming_extractors = [ "aiohttp", "requests", "pynwb>=2.3.0", + "remfile" ] full = [ diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index fd003390e3..c2e624957a 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -1,6 +1,6 @@ from __future__ import annotations from pathlib import Path -from typing import Union, List, Optional, Literal, Dict +from typing import Union, List, Optional, Literal, Dict, BinaryIO import numpy as np @@ -69,8 +69,10 @@ def retrieve_electrical_series(nwbfile: NWBFile, electrical_series_name: Optiona def read_nwbfile( - file_path: str | Path, - stream_mode: Literal["ffspec", "ros3"] | None = None, + *, + file_path: str | Path | None = None, + file: BinaryIO | None = None, + stream_mode: Literal["fsspec", "ros3", "remfile"] | None = None, cache: bool = True, stream_cache_path: str | Path | None = None, ) -> NWBFile: @@ -79,9 +81,11 @@ def read_nwbfile( Parameters ---------- - file_path : Path, str - The path to the NWB file. - stream_mode : "fsspec" or "ros3" or None, default: None + file_path : Path, str or None + The path to the NWB file. Either provide this or file. + file : file-like object or None + The file-like object to read from. Either provide this or file_path. + stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None The streaming mode to use. If None it assumes the file is on the local disk. 
cache: bool, default: True If True, the file is cached in the file passed to stream_cache_path @@ -110,12 +114,19 @@ def read_nwbfile( """ from pynwb import NWBHDF5IO + if file_path is not None and file is not None: + raise ValueError("Provide either file_path or file, not both") + if file_path is None and file is None: + raise ValueError("Provide either file_path or file") + if stream_mode == "fsspec": import fsspec import h5py from fsspec.implementations.cached import CachingFileSystem + assert file_path is not None, "file_path must be specified when using stream_mode='fsspec'" + fsspec_file_system = fsspec.filesystem("http") if cache: @@ -134,15 +145,33 @@ def read_nwbfile( elif stream_mode == "ros3": import h5py + assert file_path is not None, "file_path must be specified when using stream_mode='ros3'" + drivers = h5py.registered_drivers() assertion_msg = "ROS3 support not enbabled, use: install -c conda-forge h5py>=3.2 to enable streaming" assert "ros3" in drivers, assertion_msg io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True, driver="ros3") - else: + elif stream_mode == "remfile": + import remfile + import h5py + + assert file_path is not None, "file_path must be specified when using stream_mode='remfile'" + rfile = remfile.File(file_path) + h5_file = h5py.File(rfile, "r") + io = NWBHDF5IO(file=h5_file, mode="r", load_namespaces=True) + + elif file_path is not None: file_path = str(Path(file_path).absolute()) io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True) + else: + import h5py + + assert file is not None, "Unexpected, file is None" + h5_file = h5py.File(file, "r") + io = NWBHDF5IO(file=h5_file, mode="r", load_namespaces=True) + nwbfile = io.read() return nwbfile @@ -152,10 +181,12 @@ class NwbRecordingExtractor(BaseRecording): Parameters ---------- - file_path: str or Path - Path to NWB file or s3 url. 
+ file_path: str, Path, or None + Path to NWB file or s3 url (or None if using file instead) electrical_series_name: str or None, default: None The name of the ElectricalSeries. Used if multiple ElectricalSeries are present. + file: file-like object or None, default: None + File-like object to read from (if None, file_path must be specified) load_time_vector: bool, default: False If True, the time vector is loaded to the recording object. samples_for_rate_estimation: int, default: 100000 @@ -167,7 +198,7 @@ class NwbRecordingExtractor(BaseRecording): If True, the file is cached in the file passed to stream_cache_path if False, the file is not cached. stream_cache_path: str or Path or None, default: None - Local path for caching. If None it uses cwd + Local path for caching. If None it uses the current working directory (cwd) Returns ------- @@ -202,13 +233,15 @@ class NwbRecordingExtractor(BaseRecording): def __init__( self, - file_path: str | Path, - electrical_series_name: str = None, + file_path: str | Path | None = None, # provide either this or file + electrical_series_name: str | None = None, load_time_vector: bool = False, samples_for_rate_estimation: int = 100000, cache: bool = True, - stream_mode: Optional[Literal["fsspec", "ros3"]] = None, + stream_mode: Optional[Literal["fsspec", "ros3", "remfile"]] = None, stream_cache_path: str | Path | None = None, + *, + file: BinaryIO | None = None, # file-like - provide either this or file_path ): try: from pynwb import NWBHDF5IO, NWBFile @@ -216,13 +249,18 @@ def __init__( except ImportError: raise ImportError(self.installation_mesg) + if file_path is not None and file is not None: + raise ValueError("Provide either file_path or file, not both") + if file_path is None and file is None: + raise ValueError("Provide either file_path or file") + self.stream_mode = stream_mode self.stream_cache_path = stream_cache_path self._electrical_series_name = electrical_series_name self.file_path = file_path self._nwbfile = 
read_nwbfile( - file_path=file_path, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path + file_path=file_path, file=file, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path ) electrical_series = retrieve_electrical_series(self._nwbfile, electrical_series_name) # The indices in the electrode table corresponding to this electrical series @@ -374,15 +412,21 @@ def __init__( else: self.set_property(property_name, values) - if stream_mode not in ["fsspec", "ros3"]: - file_path = str(Path(file_path).absolute()) + if stream_mode not in ["fsspec", "ros3", "remfile"]: + if file_path is not None: + file_path = str(Path(file_path).absolute()) if stream_mode == "fsspec": - # only add stream_cache_path to kwargs if it was passed as an argument if stream_cache_path is not None: stream_cache_path = str(Path(self.stream_cache_path).absolute()) self.extra_requirements.extend(["pandas", "pynwb", "hdmf"]) self._electrical_series = electrical_series + + # set serializability bools + if file is not None: + # not json serializable if file arg is provided + self._serializability["json"] = False + self._kwargs = { "file_path": file_path, "electrical_series_name": self._electrical_series_name, @@ -391,6 +435,7 @@ def __init__( "stream_mode": stream_mode, "cache": cache, "stream_cache_path": stream_cache_path, + "file": file, } diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py index 81d7decf50..ce05dced19 100644 --- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py +++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py @@ -4,6 +4,7 @@ import pytest import numpy as np import h5py +import pickle from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor @@ -87,6 +88,70 @@ def
test_recording_s3_nwb_fsspec(tmp_path, cache): check_recordings_equal(rec, reloaded_recording) +@pytest.mark.streaming_extractors +def test_recording_s3_nwb_remfile(): + file_path = ( + "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" + ) + rec = NwbRecordingExtractor(file_path, stream_mode="remfile") + + start_frame = 0 + end_frame = 300 + num_frames = end_frame - start_frame + + num_seg = rec.get_num_segments() + num_chans = rec.get_num_channels() + dtype = rec.get_dtype() + + for segment_index in range(num_seg): + num_samples = rec.get_num_samples(segment_index=segment_index) + + full_traces = rec.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame) + assert full_traces.shape == (num_frames, num_chans) + assert full_traces.dtype == dtype + + if rec.has_scaled(): + trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) + assert trace_scaled.dtype == "float32" + + +@pytest.mark.streaming_extractors +def test_recording_s3_nwb_remfile_file_like(tmp_path): + import remfile + + file_path = ( + "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" + ) + file = remfile.File(file_path) + rec = NwbRecordingExtractor(file=file) + + start_frame = 0 + end_frame = 300 + num_frames = end_frame - start_frame + + num_seg = rec.get_num_segments() + num_chans = rec.get_num_channels() + dtype = rec.get_dtype() + + for segment_index in range(num_seg): + num_samples = rec.get_num_samples(segment_index=segment_index) + + full_traces = rec.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame) + assert full_traces.shape == (num_frames, num_chans) + assert full_traces.dtype == dtype + + if rec.has_scaled(): + trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) + assert trace_scaled.dtype == "float32" + + # test pickling + with open(tmp_path 
/ "rec.pkl", "wb") as f: + pickle.dump(rec, f) + with open(tmp_path / "rec.pkl", "rb") as f: + rec2 = pickle.load(f) + check_recordings_equal(rec, rec2) + + @pytest.mark.ros3_test @pytest.mark.streaming_extractors @pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")