From bf38ade965fb0f9bb062f0798ec8d4de762b974d Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 23 Nov 2023 14:37:42 +0100 Subject: [PATCH 1/2] add option for no caching to the NWBRecordingExtractor when streaming --- .../extractors/nwbextractors.py | 37 +++++++++++++------ .../extractors/tests/test_nwb_s3_extractor.py | 10 +++-- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 010b22975c..67f7ed6200 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -71,7 +71,8 @@ def retrieve_electrical_series(nwbfile: NWBFile, electrical_series_name: Optiona def read_nwbfile( file_path: str | Path, stream_mode: Literal["ffspec", "ros3"] | None = None, - stream_cache_path: str | Path | None = None, + cache: bool = True, + stream_cache_path: str | Path | bool = True, ) -> NWBFile: """ Read an NWB file and return the NWBFile object. @@ -82,9 +83,11 @@ def read_nwbfile( The path to the NWB file. stream_mode : "fsspec" or "ros3" or None, default: None The streaming mode to use. If None it assumes the file is on the local disk. + cache: bool, default: True + If True, the file is cached in the file passed to stream_cache_path + if False, the file is not cached. stream_cache_path : str or None, default: None The path to the cache storage - Returns ------- nwbfile : NWBFile @@ -104,7 +107,7 @@ def read_nwbfile( -------- >>> nwbfile = read_nwbfile("data.nwb", stream_mode="ros3") """ - from pynwb import NWBHDF5IO, NWBFile + from pynwb import NWBHDF5IO if stream_mode == "fsspec": import fsspec @@ -112,13 +115,19 @@ def read_nwbfile( from fsspec.implementations.cached import CachingFileSystem - stream_cache_path = stream_cache_path if stream_cache_path is not None else str(get_global_tmp_folder()) - caching_file_system = CachingFileSystem( - fs=fsspec.filesystem("http"), - cache_storage=str(stream_cache_path), - ) - cached_file = caching_file_system.open(path=file_path, mode="rb") - file = h5py.File(cached_file) + fsspec_file_system = fsspec.filesystem("http") + + if cache: + stream_cache_path = stream_cache_path if stream_cache_path is not None else str(get_global_tmp_folder()) + caching_file_system = CachingFileSystem( + fs=fsspec_file_system, + cache_storage=str(stream_cache_path), + ) + ffspec_file = caching_file_system.open(path=file_path, mode="rb") + else: + ffspec_file = fsspec_file_system.open(file_path, "rb") + + file = h5py.File(ffspec_file, "r") io = NWBHDF5IO(file=file, mode="r", load_namespaces=True) elif stream_mode == "ros3": @@ -153,6 +162,9 @@ class NwbRecordingExtractor(BaseRecording): Used if "rate" is not specified in the ElectricalSeries. stream_mode: str or None, default: None Specify the stream mode: "fsspec" or "ros3". + cache: bool, default: True + If True, the file is cached in the file passed to stream_cache_path + if False, the file is not cached. stream_cache_path: str or Path or None, default: None Local path for caching. If None it uses cwd @@ -193,6 +205,7 @@ def __init__( electrical_series_name: str = None, load_time_vector: bool = False, samples_for_rate_estimation: int = 100000, + cache: bool = True, stream_mode: Optional[Literal["fsspec", "ros3"]] = None, stream_cache_path: str | Path | None = None, ): @@ -207,7 +220,9 @@ def __init__( self._electrical_series_name = electrical_series_name self.file_path = file_path - self._nwbfile = read_nwbfile(file_path=file_path, stream_mode=stream_mode, stream_cache_path=stream_cache_path) + self._nwbfile = read_nwbfile( + file_path=file_path, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path + ) electrical_series = retrieve_electrical_series(self._nwbfile, electrical_series_name) # The indices in the electrode table corresponding to this electrical series electrodes_indices = electrical_series.electrodes.data[:] diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py index 253ca2e4ce..0ce81a6218 100644 --- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py +++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py @@ -52,12 +52,16 @@ def test_recording_s3_nwb_ros3(tmp_path): check_recordings_equal(rec, reloaded_recording) -@pytest.mark.streaming_extractors -def test_recording_s3_nwb_fsspec(tmp_path): +@pytest.mark.parametrize("cache", [True, False]) # Test with and without cache +def test_recording_s3_nwb_fsspec(tmp_path, cache): file_path = ( "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" ) - rec = NwbRecordingExtractor(file_path, stream_mode="fsspec", stream_cache_path=cache_folder) + + # Instantiate NwbRecordingExtractor with the cache parameter + rec = NwbRecordingExtractor( + file_path, stream_mode="fsspec", cache=cache, stream_cache_path=tmp_path if cache else None + ) start_frame = 0 end_frame = 300 From 0a469aca31053085a10ce47a75a9825e0926a903 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 23 Nov 2023 15:28:10 +0100 Subject: [PATCH 2/2] kwargs --- src/spikeinterface/extractors/nwbextractors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 67f7ed6200..d3118712ef 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -388,6 +388,7 @@ def __init__( "load_time_vector": load_time_vector, "samples_for_rate_estimation": samples_for_rate_estimation, "stream_mode": stream_mode, + "cache": cache, "stream_cache_path": stream_cache_path, }