From 7e7824e2164a6ab1badd6a946dcd6c7a91539fe9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Dec 2023 14:39:16 +0100 Subject: [PATCH 01/11] Start SharedMemoryRecording --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/baserecording.py | 2 +- src/spikeinterface/core/numpyextractors.py | 110 +++++++++++++++++- src/spikeinterface/core/recording_tools.py | 16 ++- .../core/tests/test_numpy_extractors.py | 26 ++++- .../core/tests/test_recording_tools.py | 16 +-- 6 files changed, 154 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 4b9fedcd6f..343b98db4a 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -8,7 +8,7 @@ # main extractor from dump and cache from .binaryrecordingextractor import BinaryRecordingExtractor, read_binary from .npzsortingextractor import NpzSortingExtractor, read_npz_sorting -from .numpyextractors import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent, NumpySnippets +from .numpyextractors import NumpyRecording, SharedMemoryRecording, NumpySorting, SharedMemorySorting, NumpyEvent, NumpySnippets from .zarrextractors import ZarrRecordingExtractor, ZarrSortingExtractor, read_zarr, get_default_zarr_compressor from .binaryfolder import BinaryFolderRecording, read_binary_folder from .sortingfolder import NumpyFolderSorting, NpzFolderSorting, read_numpy_sorting_folder, read_npz_folder diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index e2e692ded9..7892fd441c 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -490,7 +490,7 @@ def _save(self, format="binary", **save_kwargs): cached = BinaryFolderRecording(folder_path=folder) elif format == "memory": - traces_list = write_memory_recording(self, dtype=None, **job_kwargs) + traces_list, shms = write_memory_recording(self, dtype=None, **job_kwargs) from .numpyextractors import NumpyRecording cached = NumpyRecording( diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 88db042512..11a5c8a4f1 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -13,9 +13,10 @@ ) from .basesorting import minimum_spike_dtype from .core_tools import make_shared_array - +from .recording_tools import write_memory_recording from multiprocessing.shared_memory import SharedMemory + from typing import Union @@ -99,6 +100,113 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces[:, channel_indices] return traces + + + +class SharedMemoryRecording(BaseRecording): + """ + In memory recording with shared memmory buffer. + + + Parameters + ---------- + shm_names: list + List of sharedmem names. + shape_list: list + + sampling_frequency: float + The sampling frequency in Hz + t_starts: None or list of float + Times in seconds of the first sample for each segment + channel_ids: list + An optional list of channel_ids. If None, linear channels are assumed + """ + + extractor_name = "SharedMemory" + mode = "memory" + name = "SharedMemory" + + def __init__( + self, shm_names, shape_list, dtype, sampling_frequency, channel_ids=None, t_starts=None, main_shm_owner=True + ): + + + assert len(shape_list) == len(shm_names) + assert all(shape_list[0][1] == shape[1] for shape in shape_list) + + # create traces from shraedmem names + self.shms = [] + traces_list = [] + for shm_name, shape in zip(shm_names, shape_list): + shm = SharedMemory(shm_name, create=False) + self.shms.append(shm) + traces = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + traces_list.append(traces) + + if channel_ids is None: + channel_ids = np.arange(traces_list[0].shape[1]) + else: + channel_ids = np.asarray(channel_ids) + assert channel_ids.size == traces_list[0].shape[1] + BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype) + + if t_starts is not None: + assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list" + t_starts = [float(t_start) for t_start in t_starts] + + self._serializability["memory"] = True + self._serializability["json"] = False + self._serializability["pickle"] = False + + # this is important so that the main owner can unlink the share mem buffer : + self.main_shm_owner = main_shm_owner + + for i, traces, in enumerate(traces_list): + if t_starts is None: + t_start = None + else: + t_start = t_starts[i] + rec_segment = NumpyRecordingSegment(traces, sampling_frequency, t_start) + + self.add_recording_segment(rec_segment) + + self._kwargs = { + "shm_names": shm_names, + "shape_list": shape_list, + "dtype": dtype, + "sampling_frequency": sampling_frequency, + "channel_ids": channel_ids, + "t_starts": t_starts, + # this is important so that clone of this will not try to unlink the share mem buffer : + "main_shm_owner": False, + } + + + def __del__(self): + self._recording_segments =[] + for shm in self.shms: + shm.close() + if self.main_shm_owner: + shm.unlink() + + @staticmethod + def from_recording(source_recording, **job_kwargs): + traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) + + recording = SharedMemoryRecording( + shm_names=[shm.name for shm in shms], + shape_list = [traces.shape for traces in traces_list], + dtype=source_recording.dtype, + sampling_frequency=source_recording.sampling_frequency, + channel_ids=source_recording.channel_ids, + t_starts=None, + main_shm_owner=True + ) + + for shm in shms: + # the sharedmem are handle by the new SharedMemoryRecording + shm.close() + return recording class NumpySorting(BaseSorting): diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index d3aabf657c..fb0644c755 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -269,7 +269,7 @@ def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx): arr[start_frame:end_frame, :] = traces -def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=True, **job_kwargs): +def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=True, buffer_type="auto", **job_kwargs): """ Save the traces into numpy arrays (memory). try to use the SharedMemory introduce in py3.8 if n_jobs > 1 @@ -284,6 +284,7 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint= If True, output is verbose (when chunks are used) auto_cast_uint: bool, default: True If True, unsigned integers are automatically cast to int if the specified dtype is signed + buffer_type: "auto" | "numpy" | "sharedmem" {} Returns @@ -302,19 +303,28 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint= # create sharedmmep arrays = [] shm_names = [] + shms = [] shapes = [] n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1)) + if buffer_type == "auto": + if n_jobs > 1: + buffer_type = "sharedmem" + else: + buffer_type = "numpy" + for segment_index in range(recording.get_num_segments()): num_frames = recording.get_num_samples(segment_index) num_channels = recording.get_num_channels() shape = (num_frames, num_channels) shapes.append(shape) - if n_jobs > 1: + if buffer_type == "sharedmem": arr, shm = make_shared_array(shape, dtype) shm_names.append(shm.name) + shms.append(shm) else: arr = np.zeros(shape, dtype=dtype) + shms.append(None) arrays.append(arr) # use executor (loop or workers) @@ -330,7 +340,7 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint= ) executor.run() - return arrays + return arrays, shms write_memory_recording.__doc__ = write_memory_recording.__doc__.format(_shared_job_kwargs_doc) diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 4a5bffbc05..8a31783cb7 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -4,9 +4,9 @@ import pytest import numpy as np -from spikeinterface.core import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent -from spikeinterface.core import create_sorting_npz, load_extractor -from spikeinterface.core import NpzSortingExtractor +from spikeinterface.core import (BaseRecording, NumpyRecording, SharedMemoryRecording, NumpySorting, SharedMemorySorting, +NumpyEvent, create_sorting_npz, load_extractor, NpzSortingExtractor, generate_recording) + from spikeinterface.core.basesorting import minimum_spike_dtype if hasattr(pytest, "global_test_folder"): @@ -29,6 +29,23 @@ def test_NumpyRecording(): rec.save(folder=cache_folder / "test_NumpyRecording") +def test_SharedMemoryRecording(): + rec0 = generate_recording(num_channels=2, durations=[4., 3.]) + print(rec0) + job_kwargs = dict(n_jobs=1, progress_bar=True) + rec = SharedMemoryRecording.from_recording(rec0, **job_kwargs) + + d = rec.to_dict() + rec_clone = load_extractor(d) + traces = rec_clone.get_traces(start_frame=0, end_frame=30000, segment_index=0) + + + assert rec.shms[0].name == rec_clone.shms[0].name + + del traces + del rec_clone + del rec + def test_NumpySorting(): sampling_frequency = 30000 @@ -132,6 +149,7 @@ def test_NumpyEvent(): if __name__ == "__main__": # test_NumpyRecording() - test_NumpySorting() + test_SharedMemoryRecording() + # test_NumpySorting() # test_SharedMemorySorting() # test_NumpyEvent() diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 7cfc4239b6..6f3c2938c7 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -149,16 +149,16 @@ def test_write_memory_recording(): recording = recording.save() # write with loop - write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1) + traces_list, shms = write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1) - write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True) + traces_list, shms = write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True) - if platform.system() != "Windows": - # write parrallel - write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k") - - # write parrallel - write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, total_memory="200k", progress_bar=True) + # write parrallel + traces_list, shms = write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k") + # need to clean the buffer + del traces_list + for shm in shms: + shm.unlink() def test_get_random_data_chunks(): From 8b34c69ec4304e0cbc1b950b8108d292823ae18f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Dec 2023 21:32:30 +0100 Subject: [PATCH 02/11] More test for sharemem + some fix around save(format=XXX) --- src/spikeinterface/core/base.py | 18 +++++++++++++++--- src/spikeinterface/core/baserecording.py | 9 ++++----- src/spikeinterface/core/basesorting.py | 5 +++++ .../core/tests/test_baserecording.py | 10 ++++++++-- src/spikeinterface/core/zarrextractors.py | 2 ++ 5 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 5e8b8f9f1b..ef41875da7 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -839,6 +839,8 @@ def save(self, **kwargs) -> "BaseExtractor": format = kwargs.get("format", None) if format == "memory": loaded_extractor = self.save_to_memory(**kwargs) + elif format == "sharedmemory": + loaded_extractor = self.save_to_sharedmemory(**kwargs) elif format == "zarr": loaded_extractor = self.save_to_zarr(**kwargs) else: @@ -848,8 +850,16 @@ def save(self, **kwargs) -> "BaseExtractor": save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc) def save_to_memory(self, **kwargs) -> "BaseExtractor": - # used only by recording at the moment - cached = self._save(**kwargs) + + kwargs.pop("format", None) + + cached = self._save(format="memory", **kwargs) + self.copy_metadata(cached) + return cached + + def save_to_sharedmemory(self, **kwargs) -> "BaseExtractor": + kwargs.pop("format", None) + cached = self._save(format="sharedmemory", **kwargs) self.copy_metadata(cached) return cached @@ -978,6 +988,8 @@ def save_to_zarr( import zarr from .zarrextractors import read_zarr + save_kwargs.pop("format", None) + if folder is None: cache_folder = get_global_tmp_folder() if name is None: @@ -1006,7 +1018,7 @@ def save_to_zarr( save_kwargs["zarr_path"] = zarr_path save_kwargs["storage_options"] = storage_options save_kwargs["channel_chunk_size"] = channel_chunk_size - cached = self._save(verbose=verbose, **save_kwargs) + cached = self._save(format="zarr", verbose=verbose, **save_kwargs) cached = read_zarr(zarr_path) return cached diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 7892fd441c..aec1b32da9 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -441,11 +441,6 @@ def time_to_sample_index(self, time_s, segment_index=None): return rs.time_to_sample_index(time_s) def _save(self, format="binary", **save_kwargs): - """ - This function replaces the old CacheRecordingExtractor, but enables more engines - for caching a results. At the moment only "binary" with memmap is supported. - We plan to add other engines, such as zarr and NWB. - """ # handle t_starts t_starts = [] @@ -490,12 +485,16 @@ def _save(self, format="binary", **save_kwargs): cached = BinaryFolderRecording(folder_path=folder) elif format == "memory": + # TODO copy traces list to standard numpy to avoid shms not handled traces_list, shms = write_memory_recording(self, dtype=None, **job_kwargs) from .numpyextractors import NumpyRecording cached = NumpyRecording( traces_list, self.get_sampling_frequency(), t_starts=t_starts, channel_ids=self.channel_ids ) + elif format == "sharedmemory": + from .numpyextractors import SharedMemoryRecording + cached = SharedMemoryRecording.from_recording(self, **job_kwargs) elif format == "zarr": from .zarrextractors import ZarrRecordingExtractor diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 465807cc98..23a4962623 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -267,6 +267,11 @@ def _save(self, format="numpy_folder", **save_kwargs): from .numpyextractors import NumpySorting cached = NumpySorting.from_sorting(self) + + elif format == "sharedmemory": + from .numpyextractors import SharedMemorySorting + cached = SharedMemorySorting.from_sorting(self) + else: raise ValueError(f"format {format} not supported") return cached diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 4326cd15aa..5f92a1f732 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -163,14 +163,20 @@ def test_BaseRecording(): # cache to memory rec4 = rec3.save(format="memory") - traces4 = rec4.get_traces(segment_index=0) traces = rec.get_traces(segment_index=0) assert np.array_equal(traces4, traces) + # cache to sharedmemory + rec5 = rec3.save(format="sharedmemory") + traces5 = rec5.get_traces(segment_index=0) + traces = rec.get_traces(segment_index=0) + assert np.array_equal(traces5, traces) + + # cache joblib several jobs folder = cache_folder / "simple_recording2" - rec2 = rec.save(folder=folder, chunk_size=10, n_jobs=4) + rec2 = rec.save(format="binary", folder=folder, chunk_size=10, n_jobs=4) traces2 = rec2.get_traces(segment_index=0) # set/get Probe only 2 channels diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 32ab5f542a..91e89455d9 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -249,6 +249,8 @@ def read_zarr( extractor: ZarrExtractor The loaded extractor """ + # TODO @alessio : we should have something more explicit in our zarr format to tell which object it is. + # for the futur SortingResult we will have this 2 fields!!! root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) if "channel_ids" in root.keys(): return read_zarr_recording(folder_path, storage_options=storage_options) From 043f5cebd8c09aa51eb79b325fef412d7d980122 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Dec 2023 21:33:51 +0100 Subject: [PATCH 03/11] oups --- src/spikeinterface/core/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 343b98db4a..3d73e80cbe 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -86,7 +86,6 @@ from .recording_tools import ( write_binary_recording, write_to_h5_dataset_format, - write_binary_recording, get_random_data_chunks, get_channel_distances, get_closest_channels, From 0b8f743ba96a7804f0400c5b8aeb8a4dc06b821c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Dec 2023 21:40:31 +0100 Subject: [PATCH 04/11] More protection in save(format="memory") to avoid sharedmem --- src/spikeinterface/core/baserecording.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index aec1b32da9..74855ba1bb 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -485,13 +485,21 @@ def _save(self, format="binary", **save_kwargs): cached = BinaryFolderRecording(folder_path=folder) elif format == "memory": - # TODO copy traces list to standard numpy to avoid shms not handled traces_list, shms = write_memory_recording(self, dtype=None, **job_kwargs) - from .numpyextractors import NumpyRecording + if shms[0] is not None: + # if the computation was done in parrralel then traces_list is shared array + # this can lead to problem + # we need to copy back to a standard numpy array and unlink the shared buffer + traces_list = [np.array(traces, copy=True) for traces in traces_list] + for shm in shms: + shm.close() + shm.unlink() + from .numpyextractors import NumpyRecording cached = NumpyRecording( traces_list, self.get_sampling_frequency(), t_starts=t_starts, channel_ids=self.channel_ids ) + elif format == "sharedmemory": from .numpyextractors import SharedMemoryRecording cached = SharedMemoryRecording.from_recording(self, **job_kwargs) From e4ed761e42e52309ee0e3bf24d673dc47d6d0f68 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Dec 2023 18:02:26 +0100 Subject: [PATCH 05/11] Change my mind onn the signature --- src/spikeinterface/core/base.py | 14 +++---------- src/spikeinterface/core/baserecording.py | 23 +++++----------------- src/spikeinterface/core/basesorting.py | 14 ++++++------- src/spikeinterface/core/numpyextractors.py | 18 +++++++++++++++++ 4 files changed, 32 insertions(+), 37 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index ef41875da7..19ad61a0f2 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -839,8 +839,6 @@ def save(self, **kwargs) -> "BaseExtractor": format = kwargs.get("format", None) if format == "memory": loaded_extractor = self.save_to_memory(**kwargs) - elif format == "sharedmemory": - loaded_extractor = self.save_to_sharedmemory(**kwargs) elif format == "zarr": loaded_extractor = self.save_to_zarr(**kwargs) else: @@ -849,17 +847,11 @@ def save(self, **kwargs) -> "BaseExtractor": save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc) - def save_to_memory(self, **kwargs) -> "BaseExtractor": + def save_to_memory(self, sharedmem=True, **save_kwargs) -> "BaseExtractor": - kwargs.pop("format", None) - - cached = self._save(format="memory", **kwargs) - self.copy_metadata(cached) - return cached + save_kwargs.pop("format", None) - def save_to_sharedmemory(self, **kwargs) -> "BaseExtractor": - kwargs.pop("format", None) - cached = self._save(format="sharedmemory", **kwargs) + cached = self._save(format="memory", sharedmem=sharedmem, **save_kwargs) self.copy_metadata(cached) return cached diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index fb2dcb9740..8b870eb4fe 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -485,24 +485,11 @@ def _save(self, format="binary", **save_kwargs): cached = BinaryFolderRecording(folder_path=folder) elif format == "memory": - traces_list, shms = write_memory_recording(self, dtype=None, **job_kwargs) - if shms[0] is not None: - # if the computation was done in parrralel then traces_list is shared array - # this can lead to problem - # we need to copy back to a standard numpy array and unlink the shared buffer - traces_list = [np.array(traces, copy=True) for traces in traces_list] - for shm in shms: - shm.close() - shm.unlink() - - from .numpyextractors import NumpyRecording - cached = NumpyRecording( - traces_list, self.get_sampling_frequency(), t_starts=t_starts, channel_ids=self.channel_ids - ) - - elif format == "sharedmemory": - from .numpyextractors import SharedMemoryRecording - cached = SharedMemoryRecording.from_recording(self, **job_kwargs) + if kwargs.get("sharedmem", True): + from .numpyextractors import SharedMemoryRecording + cached = SharedMemoryRecording.from_recording(self, **job_kwargs) + else: + cached = NumpyRecording.from_recording(self, **job_kwargs) elif format == "zarr": from .zarrextractors import ZarrRecordingExtractor diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 23a4962623..2040288924 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -264,14 +264,12 @@ def _save(self, format="numpy_folder", **save_kwargs): cached.register_recording(self._recording) elif format == "memory": - from .numpyextractors import NumpySorting - - cached = NumpySorting.from_sorting(self) - - elif format == "sharedmemory": - from .numpyextractors import SharedMemorySorting - cached = SharedMemorySorting.from_sorting(self) - + if save_kwargs.get("sharedmem", True): + from .numpyextractors import SharedMemorySorting + cached = SharedMemorySorting.from_sorting(self) + else: + from .numpyextractors import NumpySorting + cached = NumpySorting.from_sorting(self) else: raise ValueError(f"format {format} not supported") return cached diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 11a5c8a4f1..1a8c0c4dfa 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -84,6 +84,22 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N "sampling_frequency": sampling_frequency, } + @staticmethod + def from_recording(source_recording, **job_kwargs): + traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs) + if shms[0] is not None: + # if the computation was done in parrralel then traces_list is shared array + # this can lead to problem + # we need to copy back to a standard numpy array and unlink the shared buffer + traces_list = [np.array(traces, copy=True) for traces in traces_list] + for shm in shms: + shm.close() + shm.unlink() + # TODO later : propagte t_starts ? + recording = NumpyRecording(traces_list, source_recording.get_sampling_frequency(), + t_starts=None, channel_ids=source_recording.channel_ids) + + class NumpyRecordingSegment(BaseRecordingSegment): def __init__(self, traces, sampling_frequency, t_start): @@ -193,6 +209,8 @@ def __del__(self): def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) + # TODO later : propagte t_starts ? + recording = SharedMemoryRecording( shm_names=[shm.name for shm in shms], shape_list = [traces.shape for traces in traces_list], From f84e62067d67cee80d190e1b19c6cd535cc5123f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Dec 2023 21:12:09 +0100 Subject: [PATCH 06/11] oups --- src/spikeinterface/core/tests/test_baserecording.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index c820e5899c..682d9344a7 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -162,13 +162,13 @@ def test_BaseRecording(): rec3 = BaseExtractor.load(cache_folder / "simple_recording") # cache to memory - rec4 = rec3.save(format="memory") + rec4 = rec3.save(format="memory", sahred=False) traces4 = rec4.get_traces(segment_index=0) traces = rec.get_traces(segment_index=0) assert np.array_equal(traces4, traces) # cache to sharedmemory - rec5 = rec3.save(format="sharedmemory") + rec5 = rec3.save(format="memory", sahred=True) traces5 = rec5.get_traces(segment_index=0) traces = rec.get_traces(segment_index=0) assert np.array_equal(traces5, traces) From 1b66a60c75a2699f6740001ce748ab6abd8c56e9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 9 Jan 2024 10:54:32 +0100 Subject: [PATCH 07/11] pierre's comment --- src/spikeinterface/core/tests/test_baserecording.py | 4 ++-- src/spikeinterface/core/tests/test_numpy_extractors.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 682d9344a7..f4f3891ddf 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -162,13 +162,13 @@ def test_BaseRecording(): rec3 = BaseExtractor.load(cache_folder / "simple_recording") # cache to memory - rec4 = rec3.save(format="memory", sahred=False) + rec4 = rec3.save(format="memory", shared=False) traces4 = rec4.get_traces(segment_index=0) traces = rec.get_traces(segment_index=0) assert np.array_equal(traces4, traces) # cache to sharedmemory - rec5 = rec3.save(format="memory", sahred=True) + rec5 = rec3.save(format="memory", shared=True) traces5 = rec5.get_traces(segment_index=0) traces = rec.get_traces(segment_index=0) assert np.array_equal(traces5, traces) diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 8a31783cb7..88cec043c7 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -23,7 +23,7 @@ def test_NumpyRecording(): timeseries_list.append(traces) rec = NumpyRecording(timeseries_list, sampling_frequency) - print(rec) + # print(rec) times1 = rec.get_times(1) @@ -31,7 +31,7 @@ def test_NumpyRecording(): def test_SharedMemoryRecording(): rec0 = generate_recording(num_channels=2, durations=[4., 3.]) - print(rec0) + # print(rec0) job_kwargs = dict(n_jobs=1, progress_bar=True) rec = SharedMemoryRecording.from_recording(rec0, **job_kwargs) @@ -63,7 +63,7 @@ def test_NumpySorting(): labels[1::3] = 1 labels[2::3] = 2 sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency) - print(sorting) + # print(sorting) assert sorting.get_num_segments() == 1 sorting = NumpySorting.from_times_labels([times] * 3, [labels] * 3, sampling_frequency) @@ -93,7 +93,7 @@ def test_SharedMemorySorting(): spikes["unit_index"][1::3] = 1 spikes["unit_index"][2::3] = 2 np_sorting = NumpySorting(spikes, sampling_frequency, unit_ids) - print(np_sorting) + # print(np_sorting) sorting = SharedMemorySorting.from_sorting(np_sorting) # print(sorting) From bb5fe3a175a6e3e03997a4bfa2165d5a0cecd488 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Jan 2024 12:07:28 +0000 Subject: [PATCH 08/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/__init__.py | 9 +++++- src/spikeinterface/core/base.py | 1 - src/spikeinterface/core/baserecording.py | 2 +- src/spikeinterface/core/basesorting.py | 2 ++ src/spikeinterface/core/numpyextractors.py | 30 ++++++++++--------- .../core/tests/test_baserecording.py | 1 - .../core/tests/test_numpy_extractors.py | 20 +++++++++---- .../core/tests/test_recording_tools.py | 4 ++- 8 files changed, 45 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index b6da37277a..1eb01d27ce 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -8,7 +8,14 @@ # main extractor from dump and cache from .binaryrecordingextractor import BinaryRecordingExtractor, read_binary from .npzsortingextractor import NpzSortingExtractor, read_npz_sorting -from .numpyextractors import NumpyRecording, SharedMemoryRecording, NumpySorting, SharedMemorySorting, NumpyEvent, NumpySnippets +from .numpyextractors import ( + NumpyRecording, + SharedMemoryRecording, + NumpySorting, + SharedMemorySorting, + NumpyEvent, + NumpySnippets, +) from .zarrextractors import ZarrRecordingExtractor, ZarrSortingExtractor, read_zarr, get_default_zarr_compressor from .binaryfolder import BinaryFolderRecording, read_binary_folder from .sortingfolder import NumpyFolderSorting, NpzFolderSorting, read_numpy_sorting_folder, read_npz_folder diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 19ad61a0f2..80341811b9 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -848,7 +848,6 @@ def save(self, **kwargs) -> "BaseExtractor": save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc) def save_to_memory(self, sharedmem=True, **save_kwargs) -> "BaseExtractor": - save_kwargs.pop("format", None) cached = self._save(format="memory", sharedmem=sharedmem, **save_kwargs) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 775ceaee0d..b65409e033 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -441,7 +441,6 @@ def time_to_sample_index(self, time_s, segment_index=None): return rs.time_to_sample_index(time_s) def _save(self, format="binary", **save_kwargs): - # handle t_starts t_starts = [] has_time_vectors = [] @@ -487,6 +486,7 @@ def _save(self, format="binary", **save_kwargs): elif format == "memory": if kwargs.get("sharedmem", True): from .numpyextractors import SharedMemoryRecording + cached = SharedMemoryRecording.from_recording(self, **job_kwargs) else: cached = NumpyRecording.from_recording(self, **job_kwargs) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 2040288924..ba4f4b4850 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -266,9 +266,11 @@ def _save(self, format="numpy_folder", **save_kwargs): elif format == "memory": if save_kwargs.get("sharedmem", True): from .numpyextractors import SharedMemorySorting + cached = SharedMemorySorting.from_sorting(self) else: from .numpyextractors import NumpySorting + cached = NumpySorting.from_sorting(self) else: raise ValueError(f"format {format} not supported") diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 1a8c0c4dfa..35079ee694 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -96,9 +96,12 @@ def from_recording(source_recording, **job_kwargs): shm.close() shm.unlink() # TODO later : propagte t_starts ? - recording = NumpyRecording(traces_list, source_recording.get_sampling_frequency(), - t_starts=None, channel_ids=source_recording.channel_ids) - + recording = NumpyRecording( + traces_list, + source_recording.get_sampling_frequency(), + t_starts=None, + channel_ids=source_recording.channel_ids, + ) class NumpyRecordingSegment(BaseRecordingSegment): @@ -116,13 +119,12 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces[:, channel_indices] return traces - class SharedMemoryRecording(BaseRecording): """ In memory recording with shared memmory buffer. - + Parameters ---------- @@ -137,7 +139,7 @@ class SharedMemoryRecording(BaseRecording): channel_ids: list An optional list of channel_ids. If None, linear channels are assumed """ - + extractor_name = "SharedMemory" mode = "memory" name = "SharedMemory" @@ -145,8 +147,6 @@ class SharedMemoryRecording(BaseRecording): def __init__( self, shm_names, shape_list, dtype, sampling_frequency, channel_ids=None, t_starts=None, main_shm_owner=True ): - - assert len(shape_list) == len(shm_names) assert all(shape_list[0][1] == shape[1] for shape in shape_list) @@ -177,7 +177,10 @@ def __init__( # this is important so that the main owner can unlink the share mem buffer : self.main_shm_owner = main_shm_owner - for i, traces, in enumerate(traces_list): + for ( + i, + traces, + ) in enumerate(traces_list): if t_starts is None: t_start = None else: @@ -197,9 +200,8 @@ def __init__( "main_shm_owner": False, } - def __del__(self): - self._recording_segments =[] + self._recording_segments = [] for shm in self.shms: shm.close() if self.main_shm_owner: @@ -210,15 +212,15 @@ def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) # TODO later : propagte t_starts ? - + recording = SharedMemoryRecording( shm_names=[shm.name for shm in shms], - shape_list = [traces.shape for traces in traces_list], + shape_list=[traces.shape for traces in traces_list], dtype=source_recording.dtype, sampling_frequency=source_recording.sampling_frequency, channel_ids=source_recording.channel_ids, t_starts=None, - main_shm_owner=True + main_shm_owner=True, ) for shm in shms: diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index f4f3891ddf..d6096ff397 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -173,7 +173,6 @@ def test_BaseRecording(): traces = rec.get_traces(segment_index=0) assert np.array_equal(traces5, traces) - # cache joblib several jobs folder = cache_folder / "simple_recording2" rec2 = rec.save(format="binary", folder=folder, chunk_size=10, n_jobs=4) diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 88cec043c7..c694026918 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -4,8 +4,18 @@ import pytest import numpy as np -from spikeinterface.core import (BaseRecording, NumpyRecording, SharedMemoryRecording, NumpySorting, SharedMemorySorting, -NumpyEvent, create_sorting_npz, load_extractor, NpzSortingExtractor, generate_recording) +from spikeinterface.core import ( + BaseRecording, + NumpyRecording, + SharedMemoryRecording, + NumpySorting, + SharedMemorySorting, + NumpyEvent, + create_sorting_npz, + load_extractor, + NpzSortingExtractor, + generate_recording, +) from spikeinterface.core.basesorting import minimum_spike_dtype @@ -29,17 +39,17 @@ def test_NumpyRecording(): rec.save(folder=cache_folder / "test_NumpyRecording") + def test_SharedMemoryRecording(): - rec0 = generate_recording(num_channels=2, durations=[4., 3.]) + rec0 = generate_recording(num_channels=2, durations=[4.0, 3.0]) # print(rec0) job_kwargs = dict(n_jobs=1, progress_bar=True) rec = SharedMemoryRecording.from_recording(rec0, **job_kwargs) - + d = rec.to_dict() rec_clone = load_extractor(d) traces = rec_clone.get_traces(start_frame=0, end_frame=30000, segment_index=0) - assert rec.shms[0].name == rec_clone.shms[0].name del traces diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 6f3c2938c7..5de187473b 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -151,7 +151,9 @@ def test_write_memory_recording(): # write with loop traces_list, shms = write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1) - traces_list, shms = write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True) + traces_list, shms = write_memory_recording( + recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True + ) # write parrallel traces_list, shms = write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k") From 8a5e6563451401cae790a584769e8befd0e13e3a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 12 Jan 2024 13:12:30 +0100 Subject: [PATCH 09/11] Apply suggestions from code review Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/numpyextractors.py | 8 ++++---- src/spikeinterface/core/tests/test_recording_tools.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 35079ee694..dd86b42428 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -88,7 +88,7 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs) if shms[0] is not None: - # if the computation was done in parrralel then traces_list is shared array + # if the computation was done in parallel then traces_list is shared array # this can lead to problem # we need to copy back to a standard numpy array and unlink the shared buffer traces_list = [np.array(traces, copy=True) for traces in traces_list] @@ -147,10 +147,10 @@ class SharedMemoryRecording(BaseRecording): def __init__( self, shm_names, shape_list, dtype, sampling_frequency, channel_ids=None, t_starts=None, main_shm_owner=True ): - assert len(shape_list) == len(shm_names) + assert len(shape_list) == len(shm_names), 'Each shm_name in `shm_names` must have a shape in `shape_list`' assert all(shape_list[0][1] == shape[1] for shape in shape_list) - # create traces from shraedmem names + # create traces from sharedmem names self.shms = [] traces_list = [] for shm_name, shape in zip(shm_names, shape_list): @@ -167,7 +167,7 @@ def __init__( BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype) if t_starts is not None: - assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list" + assert len(t_starts) == len(traces_list), "t_starts must be a list of same size as traces_list" t_starts = [float(t_start) for t_start in t_starts] self._serializability["memory"] = True diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 5de187473b..5e0b77a151 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -155,7 +155,7 @@ def test_write_memory_recording(): recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True ) - # write parrallel + # write parallel traces_list, shms = write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k") # need to clean the buffer del traces_list From 35b9bf528548cd213f2db0942b9b39972e7f2b38 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Jan 2024 12:12:47 +0000 Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/numpyextractors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index dd86b42428..5c5db31f88 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -147,7 +147,7 @@ class SharedMemoryRecording(BaseRecording): def __init__( self, shm_names, shape_list, dtype, sampling_frequency, channel_ids=None, t_starts=None, main_shm_owner=True ): - assert len(shape_list) == len(shm_names), 'Each shm_name in `shm_names` must have a shape in `shape_list`' + assert len(shape_list) == len(shm_names), "Each shm_name in `shm_names` must have a shape in `shape_list`" assert all(shape_list[0][1] == shape[1] for shape in shape_list) # create traces from sharedmem names From cd6802d29afa8bcf5eceefc0d4425fa5bc804c66 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 12 Jan 2024 13:15:34 +0100 Subject: [PATCH 11/11] More docstrings --- src/spikeinterface/core/numpyextractors.py | 37 ++++++++++++---------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 1a8c0c4dfa..1c09155722 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -96,9 +96,12 @@ def from_recording(source_recording, **job_kwargs): shm.close() shm.unlink() # TODO later : propagte t_starts ? - recording = NumpyRecording(traces_list, source_recording.get_sampling_frequency(), - t_starts=None, channel_ids=source_recording.channel_ids) - + recording = NumpyRecording( + traces_list, + source_recording.get_sampling_frequency(), + t_starts=None, + channel_ids=source_recording.channel_ids, + ) class NumpyRecordingSegment(BaseRecordingSegment): @@ -116,28 +119,30 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces[:, channel_indices] return traces - class SharedMemoryRecording(BaseRecording): """ In memory recording with shared memmory buffer. - Parameters ---------- shm_names: list - List of sharedmem names. + List of sharedmem names for each segment shape_list: list - + List of shape of sharedmem buffer for each segment + The first dimension is the number of samples, the second is the number of channels. + Note that the number of channels must be the same for all segments sampling_frequency: float The sampling frequency in Hz t_starts: None or list of float Times in seconds of the first sample for each segment channel_ids: list An optional list of channel_ids. If None, linear channels are assumed + main_shm_owner: bool, default: True + If True, the main instance will unlink the sharedmem buffer when deleted """ - + extractor_name = "SharedMemory" mode = "memory" name = "SharedMemory" @@ -145,8 +150,6 @@ class SharedMemoryRecording(BaseRecording): def __init__( self, shm_names, shape_list, dtype, sampling_frequency, channel_ids=None, t_starts=None, main_shm_owner=True ): - - assert len(shape_list) == len(shm_names) assert all(shape_list[0][1] == shape[1] for shape in shape_list) @@ -177,7 +180,10 @@ def __init__( # this is important so that the main owner can unlink the share mem buffer : self.main_shm_owner = main_shm_owner - for i, traces, in enumerate(traces_list): + for ( + i, + traces, + ) in enumerate(traces_list): if t_starts is None: t_start = None else: @@ -197,9 +203,8 @@ def __init__( "main_shm_owner": False, } - def __del__(self): - self._recording_segments =[] + self._recording_segments = [] for shm in self.shms: shm.close() if self.main_shm_owner: @@ -210,15 +215,15 @@ def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) # TODO later : propagte t_starts ? - + recording = SharedMemoryRecording( shm_names=[shm.name for shm in shms], - shape_list = [traces.shape for traces in traces_list], + shape_list=[traces.shape for traces in traces_list], dtype=source_recording.dtype, sampling_frequency=source_recording.sampling_frequency, channel_ids=source_recording.channel_ids, t_starts=None, - main_shm_owner=True + main_shm_owner=True, ) for shm in shms: