Commit

Merge pull request #2365 from samuelgarcia/sharedmam_rec
Create SharedmemRecording
alejoe91 authored Jan 12, 2024
2 parents 52fffb5 + 8c7eafe commit 9c99cda
Showing 10 changed files with 227 additions and 40 deletions.
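Taken together, the changes let a recording be cached to shared memory through the regular save() API. A minimal usage sketch, assuming this PR is merged (generate_recording is an existing spikeinterface.core helper; the toy parameters are illustrative):

from spikeinterface.core import generate_recording

rec = generate_recording(num_channels=4, durations=[2.0])

# new default: traces are copied into SharedMemory buffers (SharedMemoryRecording)
rec_shm = rec.save(format="memory", sharedmem=True)

# previous behavior: plain numpy arrays in the current process (NumpyRecording)
rec_np = rec.save(format="memory", sharedmem=False)

assert (rec_shm.get_traces() == rec.get_traces()).all()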
9 changes: 8 additions & 1 deletion src/spikeinterface/core/__init__.py
@@ -8,7 +8,14 @@
 # main extractor from dump and cache
 from .binaryrecordingextractor import BinaryRecordingExtractor, read_binary
 from .npzsortingextractor import NpzSortingExtractor, read_npz_sorting
-from .numpyextractors import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent, NumpySnippets
+from .numpyextractors import (
+    NumpyRecording,
+    SharedMemoryRecording,
+    NumpySorting,
+    SharedMemorySorting,
+    NumpyEvent,
+    NumpySnippets,
+)
 from .zarrextractors import ZarrRecordingExtractor, ZarrSortingExtractor, read_zarr, get_default_zarr_compressor
 from .binaryfolder import BinaryFolderRecording, read_binary_folder
 from .sortingfolder import NumpyFolderSorting, NpzFolderSorting, read_numpy_sorting_folder, read_npz_folder
11 changes: 7 additions & 4 deletions src/spikeinterface/core/base.py
@@ -847,9 +847,10 @@ def save(self, **kwargs) -> "BaseExtractor":

     save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc)

-    def save_to_memory(self, **kwargs) -> "BaseExtractor":
-        # used only by recording at the moment
-        cached = self._save(**kwargs)
+    def save_to_memory(self, sharedmem=True, **save_kwargs) -> "BaseExtractor":
+        save_kwargs.pop("format", None)
+
+        cached = self._save(format="memory", sharedmem=sharedmem, **save_kwargs)
         self.copy_metadata(cached)
         return cached

@@ -978,6 +979,8 @@ def save_to_zarr(
         import zarr
         from .zarrextractors import read_zarr

+        save_kwargs.pop("format", None)
+
         if folder is None:
             cache_folder = get_global_tmp_folder()
             if name is None:
@@ -1006,7 +1009,7 @@ def save_to_zarr(
         save_kwargs["zarr_path"] = zarr_path
         save_kwargs["storage_options"] = storage_options
         save_kwargs["channel_chunk_size"] = channel_chunk_size
-        cached = self._save(verbose=verbose, **save_kwargs)
+        cached = self._save(format="zarr", verbose=verbose, **save_kwargs)
         cached = read_zarr(zarr_path)

         return cached
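Since both helpers now pin the format themselves, a format= passed by the caller is popped and discarded rather than colliding with the pinned format in _save(). A hedged illustration of the resulting behavior (recording is an assumed in-scope extractor; this is not test code from the PR):

# the stray format="binary" is silently dropped; the result is still an in-memory recording
rec_mem = recording.save_to_memory(format="binary", sharedmem=True)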
16 changes: 5 additions & 11 deletions src/spikeinterface/core/baserecording.py
@@ -441,12 +441,6 @@ def time_to_sample_index(self, time_s, segment_index=None):
         return rs.time_to_sample_index(time_s)

     def _save(self, format="binary", **save_kwargs):
-        """
-        This function replaces the old CacheRecordingExtractor, but enables more engines
-        for caching a results. At the moment only "binary" with memmap is supported.
-        We plan to add other engines, such as zarr and NWB.
-        """
-
         # handle t_starts
         t_starts = []
         has_time_vectors = []
@@ -490,12 +484,12 @@ def _save(self, format="binary", **save_kwargs):
             cached = BinaryFolderRecording(folder_path=folder)

         elif format == "memory":
-            traces_list = write_memory_recording(self, dtype=None, **job_kwargs)
-            from .numpyextractors import NumpyRecording
-
-            cached = NumpyRecording(
-                traces_list, self.get_sampling_frequency(), t_starts=t_starts, channel_ids=self.channel_ids
-            )
+            if save_kwargs.get("sharedmem", True):
+                from .numpyextractors import SharedMemoryRecording
+
+                cached = SharedMemoryRecording.from_recording(self, **job_kwargs)
+            else:
+                from .numpyextractors import NumpyRecording
+
+                cached = NumpyRecording.from_recording(self, **job_kwargs)

         elif format == "zarr":
             from .zarrextractors import ZarrRecordingExtractor
9 changes: 7 additions & 2 deletions src/spikeinterface/core/basesorting.py
@@ -264,9 +264,14 @@ def _save(self, format="numpy_folder", **save_kwargs):
             cached.register_recording(self._recording)

         elif format == "memory":
-            from .numpyextractors import NumpySorting
-
-            cached = NumpySorting.from_sorting(self)
+            if save_kwargs.get("sharedmem", True):
+                from .numpyextractors import SharedMemorySorting
+
+                cached = SharedMemorySorting.from_sorting(self)
+            else:
+                from .numpyextractors import NumpySorting
+
+                cached = NumpySorting.from_sorting(self)
         else:
             raise ValueError(f"format {format} not supported")
         return cached
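The sorting path mirrors the recording path: sharedmem=True (the default) yields a SharedMemorySorting, otherwise a NumpySorting. A usage sketch, assuming the existing generate_sorting helper from spikeinterface.core:

from spikeinterface.core import generate_sorting

sorting = generate_sorting(num_units=5, durations=[2.0])
sort_shm = sorting.save(format="memory", sharedmem=True)   # SharedMemorySorting
sort_np = sorting.save(format="memory", sharedmem=False)   # NumpySorting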
133 changes: 132 additions & 1 deletion src/spikeinterface/core/numpyextractors.py
@@ -13,9 +13,10 @@
 )
 from .basesorting import minimum_spike_dtype
 from .core_tools import make_shared_array

+from .recording_tools import write_memory_recording
 from multiprocessing.shared_memory import SharedMemory


 from typing import Union
@@ -83,6 +84,25 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
             "sampling_frequency": sampling_frequency,
         }

+    @staticmethod
+    def from_recording(source_recording, **job_kwargs):
+        traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)
+        if shms[0] is not None:
+            # if the computation was done in parallel, then traces_list contains shared arrays;
+            # this can lead to problems, so we copy back to standard numpy arrays
+            # and unlink the shared buffers
+            traces_list = [np.array(traces, copy=True) for traces in traces_list]
+            for shm in shms:
+                shm.close()
+                shm.unlink()
+        # TODO later : propagate t_starts ?
+        recording = NumpyRecording(
+            traces_list,
+            source_recording.get_sampling_frequency(),
+            t_starts=None,
+            channel_ids=source_recording.channel_ids,
+        )
+        return recording
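The copy-then-unlink step above is load-bearing: an ndarray built on shm.buf becomes invalid once its SharedMemory is released, so traces must be copied out (and the views dropped) before close/unlink. The same pattern with only the standard library, as a self-contained sketch:

import numpy as np
from multiprocessing.shared_memory import SharedMemory

shm = SharedMemory(create=True, size=16 * 8)
shared_view = np.ndarray((16,), dtype="float64", buffer=shm.buf)
shared_view[:] = np.arange(16)

owned = np.array(shared_view, copy=True)  # detach a private copy
del shared_view  # drop the exported view, otherwise shm.close() raises BufferError
shm.close()
shm.unlink()

assert owned.sum() == 120.0  # the copy outlives the shared buffer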


 class NumpyRecordingSegment(BaseRecordingSegment):
     def __init__(self, traces, sampling_frequency, t_start):
@@ -101,6 +121,117 @@ def get_traces(self, start_frame, end_frame, channel_indices):
         return traces


+class SharedMemoryRecording(BaseRecording):
+    """
+    In memory recording with shared memory buffer.
+
+    Parameters
+    ----------
+    shm_names: list
+        List of sharedmem names for each segment
+    shape_list: list
+        List of shapes of the sharedmem buffer for each segment.
+        The first dimension is the number of samples, the second is the number of channels.
+        Note that the number of channels must be the same for all segments
+    sampling_frequency: float
+        The sampling frequency in Hz
+    t_starts: None or list of float
+        Times in seconds of the first sample for each segment
+    channel_ids: list
+        An optional list of channel_ids. If None, linear channels are assumed
+    main_shm_owner: bool, default: True
+        If True, the main instance will unlink the sharedmem buffer when deleted
+    """
+
+    extractor_name = "SharedMemory"
+    mode = "memory"
+    name = "SharedMemory"
+
+    def __init__(
+        self, shm_names, shape_list, dtype, sampling_frequency, channel_ids=None, t_starts=None, main_shm_owner=True
+    ):
+        assert len(shape_list) == len(shm_names), "Each shm_name in `shm_names` must have a shape in `shape_list`"
+        assert all(
+            shape_list[0][1] == shape[1] for shape in shape_list
+        ), "All segments must have the same number of channels"
+
+        # create traces from the sharedmem names
+        self.shms = []
+        traces_list = []
+        for shm_name, shape in zip(shm_names, shape_list):
+            shm = SharedMemory(shm_name, create=False)
+            self.shms.append(shm)
+            traces = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
+            traces_list.append(traces)
+
+        if channel_ids is None:
+            channel_ids = np.arange(traces_list[0].shape[1])
+        else:
+            channel_ids = np.asarray(channel_ids)
+            assert channel_ids.size == traces_list[0].shape[1]
+        BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype)
+
+        if t_starts is not None:
+            assert len(t_starts) == len(traces_list), "t_starts must be a list of the same size as traces_list"
+            t_starts = [float(t_start) for t_start in t_starts]
+
+        self._serializability["memory"] = True
+        self._serializability["json"] = False
+        self._serializability["pickle"] = False
+
+        # this is important so that the main owner can unlink the shared memory buffer:
+        self.main_shm_owner = main_shm_owner
+
+        for i, traces in enumerate(traces_list):
+            if t_starts is None:
+                t_start = None
+            else:
+                t_start = t_starts[i]
+            rec_segment = NumpyRecordingSegment(traces, sampling_frequency, t_start)
+            self.add_recording_segment(rec_segment)
+
+        self._kwargs = {
+            "shm_names": shm_names,
+            "shape_list": shape_list,
+            "dtype": dtype,
+            "sampling_frequency": sampling_frequency,
+            "channel_ids": channel_ids,
+            "t_starts": t_starts,
+            # this is important so that clones of this will not try to unlink the shared memory buffer:
+            "main_shm_owner": False,
+        }
+
+    def __del__(self):
+        self._recording_segments = []
+        for shm in self.shms:
+            shm.close()
+            if self.main_shm_owner:
+                shm.unlink()
+
+    @staticmethod
+    def from_recording(source_recording, **job_kwargs):
+        traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)
+
+        # TODO later : propagate t_starts ?
+        recording = SharedMemoryRecording(
+            shm_names=[shm.name for shm in shms],
+            shape_list=[traces.shape for traces in traces_list],
+            dtype=source_recording.dtype,
+            sampling_frequency=source_recording.sampling_frequency,
+            channel_ids=source_recording.channel_ids,
+            t_starts=None,
+            main_shm_owner=True,
+        )
+
+        for shm in shms:
+            # the sharedmem buffers are now handled by the new SharedMemoryRecording
+            shm.close()
+        return recording
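main_shm_owner encodes a single-owner cleanup policy: every instance closes its own handles in __del__, but only the creating instance unlinks the OS-level buffers, and _kwargs hard-codes main_shm_owner=False so any clone rebuilt from those kwargs attaches without owning. A hedged sketch of that round trip (rec is an assumed in-scope recording; reading the private _kwargs directly is for illustration only):

rec_shm = SharedMemoryRecording.from_recording(rec, n_jobs=2)  # owner

clone = SharedMemoryRecording(**rec_shm._kwargs)  # attaches by name, main_shm_owner=False

del clone    # closes its handles; the buffers stay alive
del rec_shm  # the owner closes and unlinks; the memory is released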


 class NumpySorting(BaseSorting):
     """
     In memory sorting object.
16 changes: 13 additions & 3 deletions src/spikeinterface/core/recording_tools.py
@@ -269,7 +269,7 @@ def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx):
     arr[start_frame:end_frame, :] = traces


-def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=True, **job_kwargs):
+def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=True, buffer_type="auto", **job_kwargs):
     """
     Save the traces into numpy arrays (memory).
     Tries to use the SharedMemory (introduced in Python 3.8) if n_jobs > 1.
@@ -284,6 +284,7 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
         If True, output is verbose (when chunks are used)
     auto_cast_uint: bool, default: True
         If True, unsigned integers are automatically cast to int if the specified dtype is signed
+    buffer_type: "auto" | "numpy" | "sharedmem", default: "auto"
+        The buffer backend; "auto" uses "sharedmem" when n_jobs > 1 and "numpy" otherwise
     {}

     Returns
@@ -302,19 +303,28 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
     # create sharedmem or numpy buffers
     arrays = []
     shm_names = []
+    shms = []
     shapes = []

     n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1))
+    if buffer_type == "auto":
+        if n_jobs > 1:
+            buffer_type = "sharedmem"
+        else:
+            buffer_type = "numpy"

     for segment_index in range(recording.get_num_segments()):
         num_frames = recording.get_num_samples(segment_index)
         num_channels = recording.get_num_channels()
         shape = (num_frames, num_channels)
         shapes.append(shape)
-        if n_jobs > 1:
+        if buffer_type == "sharedmem":
             arr, shm = make_shared_array(shape, dtype)
             shm_names.append(shm.name)
+            shms.append(shm)
         else:
             arr = np.zeros(shape, dtype=dtype)
+            shms.append(None)
         arrays.append(arr)

     # use executor (loop or workers)
@@ -330,7 +340,7 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
     )
     executor.run()

-    return arrays
+    return arrays, shms


 write_memory_recording.__doc__ = write_memory_recording.__doc__.format(_shared_job_kwargs_doc)
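With the new return signature, every caller of write_memory_recording becomes responsible for the returned shms handles (entries are None for numpy buffers). A hedged sketch of that cleanup contract, following the same copy-before-unlink rule as NumpyRecording.from_recording above (rec is an assumed in-scope recording):

import numpy as np

arrays, shms = write_memory_recording(rec, buffer_type="sharedmem", n_jobs=2, chunk_duration="1s")
data = [np.array(arr, copy=True) for arr in arrays]  # private copies of the traces
del arrays  # drop the shared views before releasing the buffers
for shm in shms:
    if shm is not None:
        shm.close()
        shm.unlink()  # only the owning process should unlink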
11 changes: 8 additions & 3 deletions src/spikeinterface/core/tests/test_baserecording.py
@@ -162,15 +162,20 @@ def test_BaseRecording():
     rec3 = BaseExtractor.load(cache_folder / "simple_recording")

     # cache to memory
-    rec4 = rec3.save(format="memory")
+    rec4 = rec3.save(format="memory", sharedmem=False)
     traces4 = rec4.get_traces(segment_index=0)
     traces = rec.get_traces(segment_index=0)
     assert np.array_equal(traces4, traces)

+    # cache to shared memory
+    rec5 = rec3.save(format="memory", sharedmem=True)
+    traces5 = rec5.get_traces(segment_index=0)
+    traces = rec.get_traces(segment_index=0)
+    assert np.array_equal(traces5, traces)
+
     # cache joblib several jobs
     folder = cache_folder / "simple_recording2"
-    rec2 = rec.save(folder=folder, chunk_size=10, n_jobs=4)
+    rec2 = rec.save(format="binary", folder=folder, chunk_size=10, n_jobs=4)
     traces2 = rec2.get_traces(segment_index=0)

     # set/get Probe only 2 channels
(diffs for the 3 remaining changed files not shown)
