Skip to content

Commit

Permalink
Merge pull request #2246 from h-mayorquin/add_option_for_non_caching
Browse files Browse the repository at this point in the history
Add a no-caching option to the `NwbRecordingExtractor` when streaming
  • Loading branch information
alejoe91 authored Nov 27, 2023
2 parents d269898 + 506732c commit 1244b8d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
38 changes: 27 additions & 11 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def retrieve_electrical_series(nwbfile: NWBFile, electrical_series_name: Optiona
def read_nwbfile(
file_path: str | Path,
stream_mode: Literal["ffspec", "ros3"] | None = None,
stream_cache_path: str | Path | None = None,
cache: bool = True,
stream_cache_path: str | Path | bool = True,
) -> NWBFile:
"""
Read an NWB file and return the NWBFile object.
Expand All @@ -82,9 +83,11 @@ def read_nwbfile(
The path to the NWB file.
stream_mode : "fsspec" or "ros3" or None, default: None
The streaming mode to use. If None it assumes the file is on the local disk.
cache : bool, default: True
If True, the file is cached in the folder given by stream_cache_path;
if False, the file is not cached.
stream_cache_path : str or None, default: None
The path to the cache storage
Returns
-------
nwbfile : NWBFile
Expand All @@ -104,21 +107,27 @@ def read_nwbfile(
--------
>>> nwbfile = read_nwbfile("data.nwb", stream_mode="ros3")
"""
from pynwb import NWBHDF5IO, NWBFile
from pynwb import NWBHDF5IO

if stream_mode == "fsspec":
import fsspec
import h5py

from fsspec.implementations.cached import CachingFileSystem

stream_cache_path = stream_cache_path if stream_cache_path is not None else str(get_global_tmp_folder())
caching_file_system = CachingFileSystem(
fs=fsspec.filesystem("http"),
cache_storage=str(stream_cache_path),
)
cached_file = caching_file_system.open(path=file_path, mode="rb")
file = h5py.File(cached_file)
fsspec_file_system = fsspec.filesystem("http")

if cache:
stream_cache_path = stream_cache_path if stream_cache_path is not None else str(get_global_tmp_folder())
caching_file_system = CachingFileSystem(
fs=fsspec_file_system,
cache_storage=str(stream_cache_path),
)
ffspec_file = caching_file_system.open(path=file_path, mode="rb")
else:
ffspec_file = fsspec_file_system.open(file_path, "rb")

file = h5py.File(ffspec_file, "r")
io = NWBHDF5IO(file=file, mode="r", load_namespaces=True)

elif stream_mode == "ros3":
Expand Down Expand Up @@ -153,6 +162,9 @@ class NwbRecordingExtractor(BaseRecording):
Used if "rate" is not specified in the ElectricalSeries.
stream_mode: str or None, default: None
Specify the stream mode: "fsspec" or "ros3".
cache: bool, default: True
If True, the file is cached in the folder given by stream_cache_path;
if False, the file is not cached.
stream_cache_path: str or Path or None, default: None
Local path for caching. If None, the global temporary folder is used
Expand Down Expand Up @@ -193,6 +205,7 @@ def __init__(
electrical_series_name: str = None,
load_time_vector: bool = False,
samples_for_rate_estimation: int = 100000,
cache: bool = True,
stream_mode: Optional[Literal["fsspec", "ros3"]] = None,
stream_cache_path: str | Path | None = None,
):
Expand All @@ -207,7 +220,9 @@ def __init__(
self._electrical_series_name = electrical_series_name

self.file_path = file_path
self._nwbfile = read_nwbfile(file_path=file_path, stream_mode=stream_mode, stream_cache_path=stream_cache_path)
self._nwbfile = read_nwbfile(
file_path=file_path, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path
)
electrical_series = retrieve_electrical_series(self._nwbfile, electrical_series_name)
# The indices in the electrode table corresponding to this electrical series
electrodes_indices = electrical_series.electrodes.data[:]
Expand Down Expand Up @@ -373,6 +388,7 @@ def __init__(
"load_time_vector": load_time_vector,
"samples_for_rate_estimation": samples_for_rate_estimation,
"stream_mode": stream_mode,
"cache": cache,
"stream_cache_path": stream_cache_path,
}

Expand Down
10 changes: 7 additions & 3 deletions src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,16 @@ def test_recording_s3_nwb_ros3(tmp_path):
check_recordings_equal(rec, reloaded_recording)


@pytest.mark.streaming_extractors
def test_recording_s3_nwb_fsspec(tmp_path):
@pytest.mark.parametrize("cache", [True, False]) # Test with and without cache
def test_recording_s3_nwb_fsspec(tmp_path, cache):
file_path = (
"https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
)
rec = NwbRecordingExtractor(file_path, stream_mode="fsspec", stream_cache_path=cache_folder)

# Instantiate NwbRecordingExtractor with the cache parameter
rec = NwbRecordingExtractor(
file_path, stream_mode="fsspec", cache=cache, stream_cache_path=tmp_path if cache else None
)

start_frame = 0
end_frame = 300
Expand Down

0 comments on commit 1244b8d

Please sign in to comment.