diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py
index 55d3b6360a..1eb01d27ce 100644
--- a/src/spikeinterface/core/__init__.py
+++ b/src/spikeinterface/core/__init__.py
@@ -8,7 +8,14 @@
 # main extractor from dump and cache
 from .binaryrecordingextractor import BinaryRecordingExtractor, read_binary
 from .npzsortingextractor import NpzSortingExtractor, read_npz_sorting
-from .numpyextractors import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent, NumpySnippets
+from .numpyextractors import (
+    NumpyRecording,
+    SharedMemoryRecording,
+    NumpySorting,
+    SharedMemorySorting,
+    NumpyEvent,
+    NumpySnippets,
+)
 from .zarrextractors import ZarrRecordingExtractor, ZarrSortingExtractor, read_zarr, get_default_zarr_compressor
 from .binaryfolder import BinaryFolderRecording, read_binary_folder
 from .sortingfolder import NumpyFolderSorting, NpzFolderSorting, read_numpy_sorting_folder, read_npz_folder
diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 5e8b8f9f1b..80341811b9 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -847,9 +847,10 @@ def save(self, **kwargs) -> "BaseExtractor":

     save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc)

-    def save_to_memory(self, **kwargs) -> "BaseExtractor":
-        # used only by recording at the moment
-        cached = self._save(**kwargs)
+    def save_to_memory(self, sharedmem=True, **save_kwargs) -> "BaseExtractor":
+        save_kwargs.pop("format", None)
+
+        cached = self._save(format="memory", sharedmem=sharedmem, **save_kwargs)
         self.copy_metadata(cached)
         return cached

@@ -978,6 +979,8 @@ def save_to_zarr(
         import zarr
         from .zarrextractors import read_zarr

+        save_kwargs.pop("format", None)
+
         if folder is None:
             cache_folder = get_global_tmp_folder()
             if name is None:
@@ -1006,7 +1009,7 @@ def save_to_zarr(
         save_kwargs["zarr_path"] = zarr_path
         save_kwargs["storage_options"] = storage_options
         save_kwargs["channel_chunk_size"] = channel_chunk_size
-        cached = self._save(verbose=verbose, **save_kwargs)
+        cached = self._save(format="zarr", verbose=verbose, **save_kwargs)
         cached = read_zarr(zarr_path)
         return cached
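Usage sketch (reviewer note, not part of the patch): save_to_memory now always routes through _save(format="memory", ...), and the new sharedmem flag picks the backing buffer. A minimal example, assuming this branch is installed; generate_recording is the stock helper also used in the tests below:

    from spikeinterface.core import generate_recording

    rec = generate_recording(num_channels=4, durations=[2.0])

    # default: traces are copied into multiprocessing SharedMemory buffers
    shm_rec = rec.save_to_memory(sharedmem=True)

    # opt out: plain numpy buffers (the previous NumpyRecording behavior)
    np_rec = rec.save_to_memory(sharedmem=False)

    assert (shm_rec.get_traces(segment_index=0) == np_rec.get_traces(segment_index=0)).all()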
- """ - # handle t_starts t_starts = [] has_time_vectors = [] @@ -490,12 +484,12 @@ def _save(self, format="binary", **save_kwargs): cached = BinaryFolderRecording(folder_path=folder) elif format == "memory": - traces_list = write_memory_recording(self, dtype=None, **job_kwargs) - from .numpyextractors import NumpyRecording + if kwargs.get("sharedmem", True): + from .numpyextractors import SharedMemoryRecording - cached = NumpyRecording( - traces_list, self.get_sampling_frequency(), t_starts=t_starts, channel_ids=self.channel_ids - ) + cached = SharedMemoryRecording.from_recording(self, **job_kwargs) + else: + cached = NumpyRecording.from_recording(self, **job_kwargs) elif format == "zarr": from .zarrextractors import ZarrRecordingExtractor diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 465807cc98..ba4f4b4850 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -264,9 +264,14 @@ def _save(self, format="numpy_folder", **save_kwargs): cached.register_recording(self._recording) elif format == "memory": - from .numpyextractors import NumpySorting + if save_kwargs.get("sharedmem", True): + from .numpyextractors import SharedMemorySorting - cached = NumpySorting.from_sorting(self) + cached = SharedMemorySorting.from_sorting(self) + else: + from .numpyextractors import NumpySorting + + cached = NumpySorting.from_sorting(self) else: raise ValueError(f"format {format} not supported") return cached diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 88db042512..e50e92e33d 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -13,9 +13,10 @@ ) from .basesorting import minimum_spike_dtype from .core_tools import make_shared_array - +from .recording_tools import write_memory_recording from multiprocessing.shared_memory import SharedMemory + from typing import Union @@ -83,6 +84,25 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N "sampling_frequency": sampling_frequency, } + @staticmethod + def from_recording(source_recording, **job_kwargs): + traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs) + if shms[0] is not None: + # if the computation was done in parallel then traces_list is shared array + # this can lead to problem + # we need to copy back to a standard numpy array and unlink the shared buffer + traces_list = [np.array(traces, copy=True) for traces in traces_list] + for shm in shms: + shm.close() + shm.unlink() + # TODO later : propagte t_starts ? + recording = NumpyRecording( + traces_list, + source_recording.get_sampling_frequency(), + t_starts=None, + channel_ids=source_recording.channel_ids, + ) + class NumpyRecordingSegment(BaseRecordingSegment): def __init__(self, traces, sampling_frequency, t_start): @@ -101,6 +121,117 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces +class SharedMemoryRecording(BaseRecording): + """ + In memory recording with shared memmory buffer. + + Parameters + ---------- + shm_names: list + List of sharedmem names for each segment + shape_list: list + List of shape of sharedmem buffer for each segment + The first dimension is the number of samples, the second is the number of channels. 
diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py
index 88db042512..e50e92e33d 100644
--- a/src/spikeinterface/core/numpyextractors.py
+++ b/src/spikeinterface/core/numpyextractors.py
@@ -13,9 +13,10 @@
 )
 from .basesorting import minimum_spike_dtype
 from .core_tools import make_shared_array
-
+from .recording_tools import write_memory_recording
 from multiprocessing.shared_memory import SharedMemory

+
 from typing import Union

@@ -83,6 +84,25 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
             "sampling_frequency": sampling_frequency,
         }

+    @staticmethod
+    def from_recording(source_recording, **job_kwargs):
+        traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)
+        if shms[0] is not None:
+            # if the computation was done in parallel, traces_list holds shared arrays;
+            # keeping them referenced here can lead to problems, so we copy back to
+            # standard numpy arrays and unlink the shared buffers
+            traces_list = [np.array(traces, copy=True) for traces in traces_list]
+            for shm in shms:
+                shm.close()
+                shm.unlink()
+        # TODO later: propagate t_starts?
+        recording = NumpyRecording(
+            traces_list,
+            source_recording.get_sampling_frequency(),
+            t_starts=None,
+            channel_ids=source_recording.channel_ids,
+        )
+        return recording
+

 class NumpyRecordingSegment(BaseRecordingSegment):
     def __init__(self, traces, sampling_frequency, t_start):
@@ -101,6 +121,117 @@ def get_traces(self, start_frame, end_frame, channel_indices):
         return traces


+class SharedMemoryRecording(BaseRecording):
+    """
+    In-memory recording backed by shared memory buffers.
+
+    Parameters
+    ----------
+    shm_names: list
+        List of shared memory buffer names, one per segment
+    shape_list: list
+        List of buffer shapes, one per segment.
+        The first dimension is the number of samples, the second is the number of channels.
+        Note that the number of channels must be the same for all segments
+    dtype: np.dtype
+        The dtype of the traces
+    sampling_frequency: float
+        The sampling frequency in Hz
+    t_starts: None or list of float
+        Times in seconds of the first sample for each segment
+    channel_ids: list
+        An optional list of channel_ids. If None, linear channels are assumed
+    main_shm_owner: bool, default: True
+        If True, this instance will unlink the shared memory buffers when deleted
+    """
+
+    extractor_name = "SharedMemory"
+    mode = "memory"
+    name = "SharedMemory"
+
+    def __init__(
+        self, shm_names, shape_list, dtype, sampling_frequency, channel_ids=None, t_starts=None, main_shm_owner=True
+    ):
+        assert len(shape_list) == len(shm_names), "Each shm_name in `shm_names` must have a shape in `shape_list`"
+        assert all(shape_list[0][1] == shape[1] for shape in shape_list), "All segments must have the same number of channels"
+
+        # attach traces to the shared memory buffers by name
+        self.shms = []
+        traces_list = []
+        for shm_name, shape in zip(shm_names, shape_list):
+            shm = SharedMemory(shm_name, create=False)
+            self.shms.append(shm)
+            traces = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
+            traces_list.append(traces)
+
+        if channel_ids is None:
+            channel_ids = np.arange(traces_list[0].shape[1])
+        else:
+            channel_ids = np.asarray(channel_ids)
+            assert channel_ids.size == traces_list[0].shape[1]
+        BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype)
+
+        if t_starts is not None:
+            assert len(t_starts) == len(traces_list), "t_starts must be a list of the same size as traces_list"
+            t_starts = [float(t_start) for t_start in t_starts]
+
+        self._serializability["memory"] = True
+        self._serializability["json"] = False
+        self._serializability["pickle"] = False
+
+        # important: only the main owner may unlink the shared memory buffers
+        self.main_shm_owner = main_shm_owner
+
+        for i, traces in enumerate(traces_list):
+            t_start = t_starts[i] if t_starts is not None else None
+            rec_segment = NumpyRecordingSegment(traces, sampling_frequency, t_start)
+            self.add_recording_segment(rec_segment)
+
+        self._kwargs = {
+            "shm_names": shm_names,
+            "shape_list": shape_list,
+            "dtype": dtype,
+            "sampling_frequency": sampling_frequency,
+            "channel_ids": channel_ids,
+            "t_starts": t_starts,
+            # important: clones of this object must not try to unlink the shared memory buffers
+            "main_shm_owner": False,
+        }
+
+    def __del__(self):
+        self._recording_segments = []
+        for shm in self.shms:
+            shm.close()
+            if self.main_shm_owner:
+                shm.unlink()
+
+    @staticmethod
+    def from_recording(source_recording, **job_kwargs):
+        traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)
+
+        # TODO later: propagate t_starts?
+
+        recording = SharedMemoryRecording(
+            shm_names=[shm.name for shm in shms],
+            shape_list=[traces.shape for traces in traces_list],
+            dtype=source_recording.dtype,
+            sampling_frequency=source_recording.sampling_frequency,
+            channel_ids=source_recording.channel_ids,
+            t_starts=None,
+            main_shm_owner=True,
+        )
+
+        for shm in shms:
+            # the buffers are now handled by the new SharedMemoryRecording,
+            # so the local handles can be closed
+            shm.close()
+        return recording
+
+
 class NumpySorting(BaseSorting):
     """
     In memory sorting object.
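Reviewer note on buffer ownership (not part of the patch): the instance built by from_recording owns the buffers (main_shm_owner=True), while clones rebuilt from _kwargs map the same buffers without owning them, so only the original unlinks them on deletion. A sketch mirroring test_SharedMemoryRecording below; the views returned by get_traces must be released before the owner is deleted:

    from spikeinterface.core import SharedMemoryRecording, generate_recording, load_extractor

    rec0 = generate_recording(num_channels=2, durations=[4.0, 3.0])
    rec = SharedMemoryRecording.from_recording(rec0, n_jobs=1)  # rec is the main_shm_owner

    clone = load_extractor(rec.to_dict())  # maps the same buffers, main_shm_owner=False
    assert rec.shms[0].name == clone.shms[0].name

    traces = clone.get_traces(segment_index=0)
    del traces, clone  # drop the views and the non-owner first...
    del rec            # ...then the owner, which unlinks the shared memory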
diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py
index a8f12c8a1c..6341b5b09a 100644
--- a/src/spikeinterface/core/recording_tools.py
+++ b/src/spikeinterface/core/recording_tools.py
@@ -269,7 +269,7 @@ def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx):
     arr[start_frame:end_frame, :] = traces


-def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=True, **job_kwargs):
+def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=True, buffer_type="auto", **job_kwargs):
     """
     Save the traces into numpy arrays (in memory). Tries to use SharedMemory (introduced
     in Python 3.8) when buffer_type="sharedmem", or when buffer_type="auto" and n_jobs > 1

@@ -284,6 +284,7 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
         If True, output is verbose (when chunks are used)
     auto_cast_uint: bool, default: True
         If True, unsigned integers are automatically cast to int if the specified dtype is signed
+    buffer_type: "auto" | "numpy" | "sharedmem", default: "auto"
+        The buffer backend: "numpy" for standard arrays, "sharedmem" for SharedMemory
+        buffers, "auto" to pick "sharedmem" when n_jobs > 1
     {}

     Returns
@@ -302,19 +303,28 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=

     # create sharedmem buffers
     arrays = []
     shm_names = []
+    shms = []
     shapes = []

     n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1))
+    if buffer_type == "auto":
+        if n_jobs > 1:
+            buffer_type = "sharedmem"
+        else:
+            buffer_type = "numpy"
+
     for segment_index in range(recording.get_num_segments()):
         num_frames = recording.get_num_samples(segment_index)
         num_channels = recording.get_num_channels()
         shape = (num_frames, num_channels)
         shapes.append(shape)
-        if n_jobs > 1:
+        if buffer_type == "sharedmem":
             arr, shm = make_shared_array(shape, dtype)
             shm_names.append(shm.name)
+            shms.append(shm)
         else:
             arr = np.zeros(shape, dtype=dtype)
+            shms.append(None)
         arrays.append(arr)

     # use executor (loop or workers)
@@ -330,7 +340,7 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
     )
     executor.run()

-    return arrays
+    return arrays, shms


 write_memory_recording.__doc__ = write_memory_recording.__doc__.format(_shared_job_kwargs_doc)
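Reviewer note (not part of the patch): since write_memory_recording now returns the SharedMemory handles next to the arrays, the caller becomes responsible for the buffers' lifetime whenever buffer_type resolves to "sharedmem". A sketch of the expected calling pattern, mirroring the updated test below; the shm.close() call and the None check are defensive additions, not in the source:

    from spikeinterface.core import generate_recording
    from spikeinterface.core.recording_tools import write_memory_recording

    recording = generate_recording(num_channels=2, durations=[1.0])

    traces_list, shms = write_memory_recording(recording, n_jobs=2, chunk_memory="100k")
    # n_jobs > 1 makes buffer_type="auto" resolve to "sharedmem", so clean up explicitly
    del traces_list  # release the numpy views into the buffers first
    for shm in shms:
        if shm is not None:  # entries are None when plain numpy buffers were used
            shm.close()
            shm.unlink()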
diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py
index 2986b87985..d6096ff397 100644
--- a/src/spikeinterface/core/tests/test_baserecording.py
+++ b/src/spikeinterface/core/tests/test_baserecording.py
@@ -162,15 +162,20 @@ def test_BaseRecording():
     rec3 = BaseExtractor.load(cache_folder / "simple_recording")

     # cache to memory
-    rec4 = rec3.save(format="memory")
-
+    rec4 = rec3.save(format="memory", sharedmem=False)
     traces4 = rec4.get_traces(segment_index=0)
     traces = rec.get_traces(segment_index=0)
     assert np.array_equal(traces4, traces)

+    # cache to shared memory
+    rec5 = rec3.save(format="memory", sharedmem=True)
+    traces5 = rec5.get_traces(segment_index=0)
+    traces = rec.get_traces(segment_index=0)
+    assert np.array_equal(traces5, traces)
+
     # cache joblib several jobs
     folder = cache_folder / "simple_recording2"
-    rec2 = rec.save(folder=folder, chunk_size=10, n_jobs=4)
+    rec2 = rec.save(format="binary", folder=folder, chunk_size=10, n_jobs=4)
     traces2 = rec2.get_traces(segment_index=0)

     # set/get Probe only 2 channels
diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py
index 4a5bffbc05..c694026918 100644
--- a/src/spikeinterface/core/tests/test_numpy_extractors.py
+++ b/src/spikeinterface/core/tests/test_numpy_extractors.py
@@ -4,9 +4,19 @@
 import pytest
 import numpy as np

-from spikeinterface.core import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent
-from spikeinterface.core import create_sorting_npz, load_extractor
-from spikeinterface.core import NpzSortingExtractor
+from spikeinterface.core import (
+    BaseRecording,
+    NumpyRecording,
+    SharedMemoryRecording,
+    NumpySorting,
+    SharedMemorySorting,
+    NumpyEvent,
+    create_sorting_npz,
+    load_extractor,
+    NpzSortingExtractor,
+    generate_recording,
+)
+
 from spikeinterface.core.basesorting import minimum_spike_dtype

 if hasattr(pytest, "global_test_folder"):
@@ -23,13 +33,30 @@ def test_NumpyRecording():
         timeseries_list.append(traces)

     rec = NumpyRecording(timeseries_list, sampling_frequency)
-    print(rec)
+    # print(rec)

     times1 = rec.get_times(1)

     rec.save(folder=cache_folder / "test_NumpyRecording")


+def test_SharedMemoryRecording():
+    rec0 = generate_recording(num_channels=2, durations=[4.0, 3.0])
+    # print(rec0)
+    job_kwargs = dict(n_jobs=1, progress_bar=True)
+    rec = SharedMemoryRecording.from_recording(rec0, **job_kwargs)
+
+    d = rec.to_dict()
+    rec_clone = load_extractor(d)
+    traces = rec_clone.get_traces(start_frame=0, end_frame=30000, segment_index=0)
+
+    assert rec.shms[0].name == rec_clone.shms[0].name
+
+    # release the views into the shared buffers before the owner unlinks them
+    del traces
+    del rec_clone
+    del rec
+
+
 def test_NumpySorting():
     sampling_frequency = 30000

@@ -46,7 +73,7 @@ def test_NumpySorting():
     labels[1::3] = 1
     labels[2::3] = 2
     sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency)
-    print(sorting)
+    # print(sorting)
     assert sorting.get_num_segments() == 1

     sorting = NumpySorting.from_times_labels([times] * 3, [labels] * 3, sampling_frequency)
@@ -76,7 +103,7 @@ def test_SharedMemorySorting():
     spikes["unit_index"][1::3] = 1
     spikes["unit_index"][2::3] = 2
     np_sorting = NumpySorting(spikes, sampling_frequency, unit_ids)
-    print(np_sorting)
+    # print(np_sorting)

     sorting = SharedMemorySorting.from_sorting(np_sorting)
     # print(sorting)
@@ -132,6 +159,7 @@ def test_NumpyEvent():

 if __name__ == "__main__":
     # test_NumpyRecording()
-    test_NumpySorting()
+    test_SharedMemoryRecording()
+    # test_NumpySorting()
     # test_SharedMemorySorting()
     # test_NumpyEvent()
diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py
index 7cfc4239b6..5e0b77a151 100644
--- a/src/spikeinterface/core/tests/test_recording_tools.py
+++ b/src/spikeinterface/core/tests/test_recording_tools.py
@@ -149,16 +149,18 @@ def test_write_memory_recording():
     recording = recording.save()

     # write with loop
-    write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1)
+    traces_list, shms = write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1)

-    write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True)
-
-    if platform.system() != "Windows":
-        # write parrallel
-        write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k")
+    traces_list, shms = write_memory_recording(
+        recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True
+    )

-        # write parrallel
-        write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, total_memory="200k", progress_bar=True)
+    # write parallel
+    traces_list, shms = write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k")
+    # need to clean up the shared buffers
+    del traces_list
+    for shm in shms:
+        shm.unlink()
diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index 881f6ffede..2dff3b5d21 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -249,6 +249,8 @@ def read_zarr(
     extractor: ZarrExtractor
         The loaded extractor
     """
+    # TODO @alessio: we should have something more explicit in our zarr format to tell which object it is.
+    # For the future SortingResult we will have these 2 fields!
    root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
     if "channel_ids" in root.keys():
         return read_zarr_recording(folder_path, storage_options=storage_options)
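Reviewer note on the TODO above (not part of the patch): with no explicit type tag in the zarr layout, read_zarr infers the object kind from the keys of the root group. The hunk only shows the recording branch; the sketch below illustrates the heuristic with a hypothetical helper, and the unit_ids check for sortings is assumed since it is not visible in this hunk:

    import zarr

    def sniff_zarr_object_kind(folder_path) -> str:
        # hypothetical helper illustrating the dispatch inside read_zarr
        root = zarr.open(str(folder_path), mode="r")
        if "channel_ids" in root.keys():
            return "recording"
        elif "unit_ids" in root.keys():  # assumed key for sortings
            return "sorting"
        raise ValueError(f"cannot tell what object is stored at {folder_path}")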