Skip to content

Commit

Permalink
Merge pull request #2248 from h-mayorquin/add_option_for_no_caching_s…
Browse files Browse the repository at this point in the history
…orting

Add option for no caching sorting for `NWBSortingExtractor`
  • Loading branch information
alejoe91 authored Nov 27, 2023
2 parents 1244b8d + 59da3f8 commit 1ee7053
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 43 deletions.
39 changes: 15 additions & 24 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def read_nwbfile(
file_path: str | Path,
stream_mode: Literal["ffspec", "ros3"] | None = None,
cache: bool = True,
stream_cache_path: str | Path | bool = True,
stream_cache_path: str | Path | None = None,
) -> NWBFile:
"""
Read an NWB file and return the NWBFile object.
Expand All @@ -87,7 +87,8 @@ def read_nwbfile(
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
stream_cache_path : str or None, default: None
The path to the cache storage
The path to the cache storage, when default to None it uses the a temporary
folder.
Returns
-------
nwbfile : NWBFile
Expand Down Expand Up @@ -449,8 +450,11 @@ class NwbSortingExtractor(BaseSorting):
Used if "rate" is not specified in the ElectricalSeries.
stream_mode: str or None, default: None
Specify the stream mode: "fsspec" or "ros3".
cache: bool, default: True
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
stream_cache_path: str or Path or None, default: None
Local path for caching. If None it uses cwd
Local path for caching. If None it uses the system temporary directory.
Returns
-------
Expand All @@ -470,6 +474,7 @@ def __init__(
sampling_frequency: float | None = None,
samples_for_rate_estimation: int = 100000,
stream_mode: str | None = None,
cache: bool = True,
stream_cache_path: str | Path | None = None,
):
try:
Expand All @@ -483,27 +488,10 @@ def __init__(
self._electrical_series_name = electrical_series_name

self.file_path = file_path
if stream_mode == "fsspec":
import fsspec
from fsspec.implementations.cached import CachingFileSystem
import h5py

self.stream_cache_path = stream_cache_path if stream_cache_path is not None else "cache"
self.cfs = CachingFileSystem(
fs=fsspec.filesystem("http"),
cache_storage=str(self.stream_cache_path),
)
file_path_ = self.cfs.open(file_path, "rb")
file = h5py.File(file_path_)
self.io = NWBHDF5IO(file=file, mode="r", load_namespaces=True)

elif stream_mode == "ros3":
self.io = NWBHDF5IO(file_path, mode="r", load_namespaces=True, driver="ros3")
else:
file_path_ = str(Path(file_path).absolute())
self.io = NWBHDF5IO(file_path_, mode="r", load_namespaces=True)
self._nwbfile = read_nwbfile(
file_path=file_path, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path
)

self._nwbfile = self.io.read()
units_ids = list(self._nwbfile.units.id[:])

timestamps = None
Expand Down Expand Up @@ -561,12 +549,15 @@ def __init__(
if stream_mode not in ["fsspec", "ros3"]:
file_path = str(Path(file_path).absolute())
if stream_mode == "fsspec":
stream_cache_path = str(Path(self.stream_cache_path).absolute())
# only add stream_cache_path to kwargs if it was passed as an argument
if stream_cache_path is not None:
stream_cache_path = str(Path(self.stream_cache_path).absolute())
self._kwargs = {
"file_path": file_path,
"electrical_series_name": self._electrical_series_name,
"sampling_frequency": sampling_frequency,
"samples_for_rate_estimation": samples_for_rate_estimation,
"cache": cache,
"stream_mode": stream_mode,
"stream_cache_path": stream_cache_path,
}
Expand Down
36 changes: 17 additions & 19 deletions src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@
from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal
from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "extractors"
else:
cache_folder = Path("cache_folder") / "extractors"


@pytest.mark.ros3_test
@pytest.mark.streaming_extractors
Expand Down Expand Up @@ -125,35 +120,38 @@ def test_sorting_s3_nwb_ros3(tmp_path):


@pytest.mark.streaming_extractors
def test_sorting_s3_nwb_fsspec(tmp_path):
@pytest.mark.parametrize("cache", [True, False]) # Test with and without cache
def test_sorting_s3_nwb_fsspec(tmp_path, cache):
file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
# we provide the 'sampling_frequency' because the NWB file does not the electrical series
sort = NwbSortingExtractor(
file_path, sampling_frequency=30000, stream_mode="fsspec", stream_cache_path=cache_folder
# We provide the 'sampling_frequency' because the NWB file does not have the electrical series
sorting = NwbSortingExtractor(
file_path,
sampling_frequency=30000.0,
stream_mode="fsspec",
cache=cache,
stream_cache_path=tmp_path if cache else None,
)

start_frame = 0
end_frame = 300
num_frames = end_frame - start_frame

num_seg = sort.get_num_segments()
num_units = len(sort.unit_ids)
num_seg = sorting.get_num_segments()
assert num_seg == 1
num_units = len(sorting.unit_ids)
assert num_units == 64

for segment_index in range(num_seg):
for unit in sort.unit_ids:
spike_train = sort.get_unit_spike_train(unit_id=unit, segment_index=segment_index)
for unit in sorting.unit_ids:
spike_train = sorting.get_unit_spike_train(unit_id=unit, segment_index=segment_index)
assert len(spike_train) > 0
assert spike_train.dtype == "int64"
assert np.all(spike_train >= 0)

tmp_file = tmp_path / "test_fsspec_sorting.pkl"
with open(tmp_file, "wb") as f:
pickle.dump(sort, f)
pickle.dump(sorting, f)

with open(tmp_file, "rb") as f:
reloaded_sorting = pickle.load(f)

check_sortings_equal(reloaded_sorting, sort)
check_sortings_equal(reloaded_sorting, sorting)


if __name__ == "__main__":
Expand Down

0 comments on commit 1ee7053

Please sign in to comment.