From a8e9924ac9d14dd7ec5f116112866846eac2e9e2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 10:42:07 +0200 Subject: [PATCH 01/14] Start waveforme xtarctor in one buffer --- src/spikeinterface/core/waveform_tools.py | 217 +++++++++++++++++++++- 1 file changed, 213 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a10c209f47..a68f8cfd5f 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -257,8 +257,8 @@ def distribute_waveforms_to_buffers( inds_by_unit[unit_id] = inds # and run - func = _waveform_extractor_chunk - init_func = _init_worker_waveform_extractor + func = _worker_ditribute_buffers + init_func = _init_worker_ditribute_buffers init_args = ( recording, @@ -282,7 +282,7 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor -def _init_worker_waveform_extractor( +def _init_worker_ditribute_buffers( recording, unit_ids, spikes, wfs_arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker @@ -328,7 +328,216 @@ def _init_worker_waveform_extractor( # used by ChunkRecordingExecutor -def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): +def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + recording = worker_ctx["recording"] + unit_ids = worker_ctx["unit_ids"] + spikes = worker_ctx["spikes"] + nbefore = worker_ctx["nbefore"] + nafter = worker_ctx["nafter"] + return_scaled = worker_ctx["return_scaled"] + inds_by_unit = worker_ctx["inds_by_unit"] + sparsity_mask = worker_ctx["sparsity_mask"] + + seg_size = recording.get_num_samples(segment_index=segment_index) + + # take only spikes with the correct segment_index + # this is a slice so no copy!! + s0 = np.searchsorted(spikes["segment_index"], segment_index) + s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + in_seg_spikes = spikes[s0:s1] + + # take only spikes in range [start_frame, end_frame] + # this is a slice so no copy!! 
+ i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) + i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) + if i0 != i1: + # protect from spikes on border : spike_time<0 or spike_time>seg_size + # useful only when max_spikes_per_unit is not None + # waveform will not be extracted and a zeros will be left in the memmap file + while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): + i0 = i0 + 1 + while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): + i1 = i1 - 1 + + # slice in absolut in spikes vector + l0 = i0 + s0 + l1 = i1 + s0 + + if l1 > l0: + start = spikes[l0]["sample_index"] - nbefore + end = spikes[l1 - 1]["sample_index"] + nafter + + # load trace in memory + traces = recording.get_traces( + start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled + ) + + for unit_ind, unit_id in enumerate(unit_ids): + # find pos + inds = inds_by_unit[unit_id] + (in_chunk_pos,) = np.nonzero((inds >= l0) & (inds < l1)) + if in_chunk_pos.size == 0: + continue + + if worker_ctx["mode"] == "memmap": + # open file in demand (and also autoclose it after) + filename = worker_ctx["wfs_arrays_info"][unit_id] + wfs = np.load(str(filename), mmap_mode="r+") + elif worker_ctx["mode"] == "shared_memory": + wfs = worker_ctx["wfs_arrays"][unit_id] + + for pos in in_chunk_pos: + sample_index = spikes[inds[pos]]["sample_index"] + wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] + + if sparsity_mask is None: + wfs[pos, :, :] = wf + else: + wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] + + +def extract_waveforms_to_unique_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode="memmap", + return_scaled=False, + folder=None, + dtype=None, + sparsity_mask=None, + copy=False, + **job_kwargs, +): + + nsamples = nbefore + nafter + + dtype = np.dtype(dtype) + if mode == "shared_memory": + assert folder is None + else: + folder = Path(folder) + + num_spikes = spike.size + if sparsity_mask is None: + num_chans = recording.get_num_channels() + else: + num_chans = np.sum(sparsity_mask[unit_ind, :]) + shape = (num_spikes, nsamples, num_chans) + + if mode == "memmap": + filename = str(folder / f"all_waveforms.npy") + wfs_array = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) + wf_array_info = filename + elif mode == "shared_memory": + if num_spikes == 0: + wfs_array = np.zeros(shape, dtype=dtype) + shm = None + shm_name = None + else: + wfs_array, shm = make_shared_array(shape, dtype) + shm_name = shm.name + wf_array_info = (shm, shm_name, dtype.str, shape) + else: + raise ValueError("allocate_waveforms_buffers bad mode") + + + job_kwargs = fix_job_kwargs(job_kwargs) + + inds_by_unit = {} + for unit_ind, unit_id in enumerate(unit_ids): + (inds,) = np.nonzero(spikes["unit_index"] == unit_ind) + inds_by_unit[unit_id] = inds + + if num_spikes > 0: + # and run + func = _worker_ditribute_one_buffer + init_func = _init_worker_ditribute_buffers + + init_args = ( + recording, + unit_ids, + spikes, + wf_array_info, + nbefore, + nafter, + return_scaled, + mode, + sparsity_mask, + ) + processor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs + ) + processor.run() + + + # if mode == "memmap": + # return wfs_arrays + # elif mode == "shared_memory": + # if copy: + # wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} + # # release all sharedmem buffer + # for unit_id in 
unit_ids: + # shm = wfs_arrays_info[unit_id][0] + # if shm is not None: + # # empty array have None + # shm.unlink() + # return wfs_arrays + # else: + # return wfs_arrays, wfs_arrays_info + + + + +def _init_worker_ditribute_one_buffer( + recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask +): + + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["wf_array_info"] = wf_array_info + + if mode == "memmap": + filename = wf_array_info + wfs = np.load(str(filename), mmap_mode="r+") + + # in memmap mode we have the "too many open file" problem with linux + # memmap file will be open on demand and not globally per worker + worker_ctx["wf_array_info"] = wf_array_info + elif mode == "shared_memory": + from multiprocessing.shared_memory import SharedMemory + + wfs_arrays = {} + shms = {} + for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): + if shm_name is None: + arr = np.zeros(shape=shape, dtype=dtype) + else: + shm = SharedMemory(shm_name) + arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + wfs_arrays[unit_id] = arr + # we need a reference to all sham otherwise we get segment fault!!! + shms[unit_id] = shm + worker_ctx["shms"] = shms + worker_ctx["wfs_arrays"] = wfs_arrays + + worker_ctx["unit_ids"] = unit_ids + worker_ctx["spikes"] = spikes + + worker_ctx["nbefore"] = nbefore + worker_ctx["nafter"] = nafter + worker_ctx["return_scaled"] = return_scaled + worker_ctx["inds_by_unit"] = inds_by_unit + worker_ctx["sparsity_mask"] = sparsity_mask + worker_ctx["mode"] = mode + + return worker_ctx + + +# used by ChunkRecordingExecutor +def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] From bb4457505e22fb5074cf17060391e4e5ce91a80f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 14:37:15 +0200 Subject: [PATCH 02/14] wip waveform tools speedup --- .../core/tests/test_waveform_tools.py | 29 +++- src/spikeinterface/core/waveform_tools.py | 124 ++++++++---------- 2 files changed, 77 insertions(+), 76 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index a896ff9c8b..457be5cba4 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -7,7 +7,7 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( - extract_waveforms_to_buffers, + extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, ) # allocate_waveforms_buffers, distribute_waveforms_to_buffers @@ -64,8 +64,7 @@ def test_waveform_tools(): if wf_folder.is_dir(): shutil.rmtree(wf_folder) wf_folder.mkdir() - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, **job_kwargs) + wfs_arrays = extract_waveforms_to_buffers( recording, spikes, @@ -84,8 +83,27 @@ def test_waveform_tools(): wf = wfs_arrays[unit_id] assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) + + wfs_array = extract_waveforms_to_unique_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode="memmap", + 
return_scaled=False, + folder=wf_folder, + dtype=dtype, + sparsity_mask=None, + copy=False, + **job_kwargs, + ) + print(wfs_array.shape) + _check_all_wf_equal(list_wfs) + + # memory if platform.system() != "Windows": # shared memory on windows is buggy... @@ -125,9 +143,6 @@ def test_waveform_tools(): sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool") job_kwargs = {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True} - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype, sparsity_mask=sparsity_mask) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, sparsity_mask=sparsity_mask, **job_kwargs) - wfs_arrays = extract_waveforms_to_buffers( recording, spikes, @@ -144,5 +159,7 @@ def test_waveform_tools(): ) + + if __name__ == "__main__": test_waveform_tools() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a68f8cfd5f..a7c8493381 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -214,6 +214,7 @@ def distribute_waveforms_to_buffers( return_scaled, mode="memmap", sparsity_mask=None, + job_name=None, **job_kwargs, ): """ @@ -272,9 +273,9 @@ def distribute_waveforms_to_buffers( mode, sparsity_mask, ) - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs - ) + if job_name is None: + job_name=f"extract waveforms {mode} multi buffer" + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() @@ -409,6 +410,7 @@ def extract_waveforms_to_unique_buffer( dtype=None, sparsity_mask=None, copy=False, + job_name=None, **job_kwargs, ): @@ -420,11 +422,11 @@ def extract_waveforms_to_unique_buffer( else: folder = Path(folder) - num_spikes = spike.size + num_spikes = spikes.size if sparsity_mask is None: num_chans = recording.get_num_channels() else: - num_chans = np.sum(sparsity_mask[unit_ind, :]) + num_chans = max(np.sum(sparsity_mask, axis=1)) shape = (num_spikes, nsamples, num_chans) if mode == "memmap": @@ -454,7 +456,7 @@ def extract_waveforms_to_unique_buffer( if num_spikes > 0: # and run func = _worker_ditribute_one_buffer - init_func = _init_worker_ditribute_buffers + init_func = _init_worker_ditribute_one_buffer init_args = ( recording, @@ -466,27 +468,27 @@ def extract_waveforms_to_unique_buffer( return_scaled, mode, sparsity_mask, + ) - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs - ) + if job_name is None: + job_name = f"extract waveforms {mode} mono buffer" + + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() - # if mode == "memmap": - # return wfs_arrays - # elif mode == "shared_memory": - # if copy: - # wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} - # # release all sharedmem buffer - # for unit_id in unit_ids: - # shm = wfs_arrays_info[unit_id][0] - # if shm is not None: - # # empty array have None - # shm.unlink() - # return wfs_arrays - # else: - # return wfs_arrays, wfs_arrays_info + if mode == "memmap": + return wfs_array + elif mode == "shared_memory": + if copy: + wf_array_info = wf_array_info.copy() + if shm is not None: + # release all sharedmem 
buffer + # empty array have None + shm.unlink() + return wfs_array + else: + return wfs_array, wf_array_info @@ -501,35 +503,29 @@ def _init_worker_ditribute_one_buffer( if mode == "memmap": filename = wf_array_info - wfs = np.load(str(filename), mmap_mode="r+") - - # in memmap mode we have the "too many open file" problem with linux - # memmap file will be open on demand and not globally per worker - worker_ctx["wf_array_info"] = wf_array_info + wfs_array = np.load(str(filename), mmap_mode="r+") + worker_ctx["wfs_array"] = wfs_array elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - - wfs_arrays = {} - shms = {} - for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): - if shm_name is None: - arr = np.zeros(shape=shape, dtype=dtype) - else: - shm = SharedMemory(shm_name) - arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - wfs_arrays[unit_id] = arr - # we need a reference to all sham otherwise we get segment fault!!! - shms[unit_id] = shm - worker_ctx["shms"] = shms - worker_ctx["wfs_arrays"] = wfs_arrays + shm, shm_name, dtype, shape = wf_array_info + shm = SharedMemory(shm_name) + wfs_array = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + worker_ctx["shm"] = shm + worker_ctx["wfs_array"] = wfs_array + + # prepare segment slices + segment_slices = [] + for segment_index in range(recording.get_num_segments()): + s0 = np.searchsorted(spikes["segment_index"], segment_index) + s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + segment_slices.append((s0, s1)) + worker_ctx["segment_slices"] = segment_slices worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore worker_ctx["nafter"] = nafter worker_ctx["return_scaled"] = return_scaled - worker_ctx["inds_by_unit"] = inds_by_unit worker_ctx["sparsity_mask"] = sparsity_mask worker_ctx["mode"] = mode @@ -541,20 +537,18 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] + segment_slices = worker_ctx["segment_slices"] spikes = worker_ctx["spikes"] nbefore = worker_ctx["nbefore"] nafter = worker_ctx["nafter"] return_scaled = worker_ctx["return_scaled"] - inds_by_unit = worker_ctx["inds_by_unit"] sparsity_mask = worker_ctx["sparsity_mask"] + wfs_array = worker_ctx["wfs_array"] seg_size = recording.get_num_samples(segment_index=segment_index) - # take only spikes with the correct segment_index - # this is a slice so no copy!! - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) - in_seg_spikes = spikes[s0:s1] + s0, s1 = segment_slices[segment_index] + in_seg_spikes = spikes[s0: s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! 
@@ -582,28 +576,18 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled ) - for unit_ind, unit_id in enumerate(unit_ids): - # find pos - inds = inds_by_unit[unit_id] - (in_chunk_pos,) = np.nonzero((inds >= l0) & (inds < l1)) - if in_chunk_pos.size == 0: - continue + for spike_ind in range(l0, l1): + sample_index = spikes[spike_ind]["sample_index"] + unit_index = spikes[spike_ind]["unit_index"] + wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] - if worker_ctx["mode"] == "memmap": - # open file in demand (and also autoclose it after) - filename = worker_ctx["wfs_arrays_info"][unit_id] - wfs = np.load(str(filename), mmap_mode="r+") - elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["wfs_arrays"][unit_id] - - for pos in in_chunk_pos: - sample_index = spikes[inds[pos]]["sample_index"] - wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] + if sparsity_mask is None: + wfs_array[spike_ind, :, :] = None + else: + mask = sparsity_mask[unit_index, :] + wf = wf[:, mask] + wfs_array[spike_ind, :, :wf.shape[1]] = wf - if sparsity_mask is None: - wfs[pos, :, :] = wf - else: - wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] def has_exceeding_spikes(recording, sorting): From 5539bdb0028b2bdcebb6b4e6d201f0579770f3e1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 17:02:02 +0200 Subject: [PATCH 03/14] extract_waveforms_to_unique_buffer is more or less OK --- .../core/tests/test_waveform_tools.py | 181 ++++++++---------- src/spikeinterface/core/waveform_tools.py | 59 +++--- 2 files changed, 115 insertions(+), 125 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 457be5cba4..ef75180898 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -7,8 +7,8 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( - extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, -) # allocate_waveforms_buffers, distribute_waveforms_to_buffers + extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, split_waveforms_by_units, +) if hasattr(pytest, "global_test_folder"): @@ -21,6 +21,10 @@ def _check_all_wf_equal(list_wfs_arrays): wfs_arrays0 = list_wfs_arrays[0] for i, wfs_arrays in enumerate(list_wfs_arrays): for unit_id in wfs_arrays.keys(): + print() + print('*'*10) + print(wfs_arrays[unit_id].shape) + print(wfs_arrays0[unit_id].shape) assert np.array_equal(wfs_arrays[unit_id], wfs_arrays0[unit_id]) @@ -52,111 +56,86 @@ def test_waveform_tools(): unit_ids = sorting.unit_ids some_job_kwargs = [ - {}, {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True}, {"n_jobs": 2, "chunk_size": 3000, "progress_bar": True}, ] + some_modes = [ + {"mode" : "memmap"}, + ] + if platform.system() != "Windows": + # shared memory on windows is buggy... 
+ some_modes.append({"mode" : "shared_memory", }) + + some_sparsity = [ + dict(sparsity_mask=None), + dict(sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), + ] + # memmap mode - list_wfs = [] + list_wfs_dense = [] + list_wfs_sparse = [] for j, job_kwargs in enumerate(some_job_kwargs): - wf_folder = cache_folder / f"test_waveform_tools_{j}" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - wf_folder.mkdir() - - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=None, - copy=False, - **job_kwargs, - ) - for unit_ind, unit_id in enumerate(unit_ids): - wf = wfs_arrays[unit_id] - assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) - - wfs_array = extract_waveforms_to_unique_buffer( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=None, - copy=False, - **job_kwargs, - ) - print(wfs_array.shape) - - _check_all_wf_equal(list_wfs) - - - - # memory - if platform.system() != "Windows": - # shared memory on windows is buggy... - list_wfs = [] - for job_kwargs in some_job_kwargs: - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='shared_memory', folder=None, dtype=dtype) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, mode='shared_memory', **job_kwargs) - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - folder=None, - dtype=dtype, - sparsity_mask=None, - copy=True, - **job_kwargs, - ) - for unit_ind, unit_id in enumerate(unit_ids): - wf = wfs_arrays[unit_id] - assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) - # to avoid warning we need to first destroy arrays then sharedmemm object - # del wfs_arrays - # del wfs_arrays_info - _check_all_wf_equal(list_wfs) - - # with sparsity - wf_folder = cache_folder / "test_waveform_tools_sparse" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - wf_folder.mkdir() - - sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool") - job_kwargs = {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True} - - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=sparsity_mask, - copy=False, - **job_kwargs, - ) + for k, mode_kwargs in enumerate(some_modes): + for l, sparsity_kwargs in enumerate(some_sparsity): + + print() + print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) + + if mode_kwargs["mode"] == "memmap": + wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" + if wf_folder.is_dir(): + shutil.rmtree(wf_folder) + wf_folder.mkdir() + mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) + else: + mode_kwargs_ = mode_kwargs + + wfs_arrays = extract_waveforms_to_buffers( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + dtype=dtype, + copy=True, + **sparsity_kwargs, + **mode_kwargs_, + **job_kwargs, + ) + for unit_ind, 
unit_id in enumerate(unit_ids): + wf = wfs_arrays[unit_id] + assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) + + if sparsity_kwargs['sparsity_mask'] is None: + list_wfs_dense.append(wfs_arrays) + else: + list_wfs_sparse.append(wfs_arrays) + + + all_waveforms = extract_waveforms_to_unique_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + dtype=dtype, + copy=True, + **sparsity_kwargs, + **mode_kwargs_, + **job_kwargs, + ) + wfs_arrays = split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs['sparsity_mask']) + if sparsity_kwargs['sparsity_mask'] is None: + list_wfs_dense.append(wfs_arrays) + else: + list_wfs_sparse.append(wfs_arrays) + + _check_all_wf_equal(list_wfs_dense) + _check_all_wf_equal(list_wfs_sparse) + diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a7c8493381..1a2361ec97 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -189,7 +189,7 @@ def allocate_waveforms_buffers( wfs_arrays[unit_id] = arr wfs_arrays_info[unit_id] = filename elif mode == "shared_memory": - if n_spikes == 0: + if n_spikes == 0 or num_chans == 0: arr = np.zeros(shape, dtype=dtype) shm = None shm_name = None @@ -431,15 +431,15 @@ def extract_waveforms_to_unique_buffer( if mode == "memmap": filename = str(folder / f"all_waveforms.npy") - wfs_array = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) + all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) wf_array_info = filename elif mode == "shared_memory": - if num_spikes == 0: - wfs_array = np.zeros(shape, dtype=dtype) + if num_spikes == 0 or num_chans == 0: + all_waveforms = np.zeros(shape, dtype=dtype) shm = None shm_name = None else: - wfs_array, shm = make_shared_array(shape, dtype) + all_waveforms, shm = make_shared_array(shape, dtype) shm_name = shm.name wf_array_info = (shm, shm_name, dtype.str, shape) else: @@ -478,17 +478,16 @@ def extract_waveforms_to_unique_buffer( if mode == "memmap": - return wfs_array + return all_waveforms elif mode == "shared_memory": if copy: - wf_array_info = wf_array_info.copy() if shm is not None: # release all sharedmem buffer # empty array have None shm.unlink() - return wfs_array + return all_waveforms.copy() else: - return wfs_array, wf_array_info + return all_waveforms, wf_array_info @@ -500,18 +499,25 @@ def _init_worker_ditribute_one_buffer( worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["wf_array_info"] = wf_array_info + worker_ctx["unit_ids"] = unit_ids + worker_ctx["spikes"] = spikes + worker_ctx["nbefore"] = nbefore + worker_ctx["nafter"] = nafter + worker_ctx["return_scaled"] = return_scaled + worker_ctx["sparsity_mask"] = sparsity_mask + worker_ctx["mode"] = mode if mode == "memmap": filename = wf_array_info - wfs_array = np.load(str(filename), mmap_mode="r+") - worker_ctx["wfs_array"] = wfs_array + all_waveforms = np.load(str(filename), mmap_mode="r+") + worker_ctx["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory shm, shm_name, dtype, shape = wf_array_info shm = SharedMemory(shm_name) - wfs_array = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) worker_ctx["shm"] = shm - worker_ctx["wfs_array"] = wfs_array + worker_ctx["all_waveforms"] = all_waveforms # prepare segment slices segment_slices = [] @@ 
-521,14 +527,6 @@ def _init_worker_ditribute_one_buffer( segment_slices.append((s0, s1)) worker_ctx["segment_slices"] = segment_slices - worker_ctx["unit_ids"] = unit_ids - worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled - worker_ctx["sparsity_mask"] = sparsity_mask - worker_ctx["mode"] = mode - return worker_ctx @@ -543,7 +541,7 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c nafter = worker_ctx["nafter"] return_scaled = worker_ctx["return_scaled"] sparsity_mask = worker_ctx["sparsity_mask"] - wfs_array = worker_ctx["wfs_array"] + all_waveforms = worker_ctx["all_waveforms"] seg_size = recording.get_num_samples(segment_index=segment_index) @@ -582,12 +580,25 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] if sparsity_mask is None: - wfs_array[spike_ind, :, :] = None + all_waveforms[spike_ind, :, :] = wf else: mask = sparsity_mask[unit_index, :] wf = wf[:, mask] - wfs_array[spike_ind, :, :wf.shape[1]] = wf + all_waveforms[spike_ind, :, :wf.shape[1]] = wf + + +def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None): + waveform_by_units = {} + for unit_index, unit_id in enumerate(unit_ids): + mask = spikes["unit_index"] == unit_index + if sparsity_mask is not None: + chan_mask = sparsity_mask[unit_index, :] + num_chans = np.sum(chan_mask) + waveform_by_units[unit_id] = all_waveforms[mask, :, :][:, :, :num_chans] + else: + waveform_by_units[unit_id] = all_waveforms[mask, :, :] + return waveform_by_units def has_exceeding_spikes(recording, sorting): From ec86e2987414dc060ff657531dd3be3b475bf49b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 17:02:34 +0200 Subject: [PATCH 04/14] clean --- src/spikeinterface/core/tests/test_waveform_tools.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index ef75180898..c86aa6d5d7 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -21,10 +21,6 @@ def _check_all_wf_equal(list_wfs_arrays): wfs_arrays0 = list_wfs_arrays[0] for i, wfs_arrays in enumerate(list_wfs_arrays): for unit_id in wfs_arrays.keys(): - print() - print('*'*10) - print(wfs_arrays[unit_id].shape) - print(wfs_arrays0[unit_id].shape) assert np.array_equal(wfs_arrays[unit_id], wfs_arrays0[unit_id]) @@ -79,8 +75,8 @@ def test_waveform_tools(): for k, mode_kwargs in enumerate(some_modes): for l, sparsity_kwargs in enumerate(some_sparsity): - print() - print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) + # print() + # print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) if mode_kwargs["mode"] == "memmap": wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" From da729088eb539991dc315c66a634ac23cc0f136f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jul 2023 15:03:10 +0000 Subject: [PATCH 05/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../core/tests/test_waveform_tools.py | 32 ++++++++++--------- src/spikeinterface/core/waveform_tools.py | 14 +++----- 2 files changed, 21 insertions(+), 25 deletions(-) 
diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index c86aa6d5d7..9a51a10ee2 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -7,7 +7,9 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( - extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, split_waveforms_by_units, + extract_waveforms_to_buffers, + extract_waveforms_to_unique_buffer, + split_waveforms_by_units, ) @@ -56,17 +58,20 @@ def test_waveform_tools(): {"n_jobs": 2, "chunk_size": 3000, "progress_bar": True}, ] some_modes = [ - {"mode" : "memmap"}, + {"mode": "memmap"}, ] if platform.system() != "Windows": # shared memory on windows is buggy... - some_modes.append({"mode" : "shared_memory", }) + some_modes.append( + { + "mode": "shared_memory", + } + ) some_sparsity = [ dict(sparsity_mask=None), - dict(sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), + dict(sparsity_mask=np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), ] - # memmap mode list_wfs_dense = [] @@ -74,7 +79,6 @@ def test_waveform_tools(): for j, job_kwargs in enumerate(some_job_kwargs): for k, mode_kwargs in enumerate(some_modes): for l, sparsity_kwargs in enumerate(some_sparsity): - # print() # print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) @@ -86,7 +90,7 @@ def test_waveform_tools(): mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) else: mode_kwargs_ = mode_kwargs - + wfs_arrays = extract_waveforms_to_buffers( recording, spikes, @@ -103,13 +107,12 @@ def test_waveform_tools(): for unit_ind, unit_id in enumerate(unit_ids): wf = wfs_arrays[unit_id] assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - - if sparsity_kwargs['sparsity_mask'] is None: + + if sparsity_kwargs["sparsity_mask"] is None: list_wfs_dense.append(wfs_arrays) else: list_wfs_sparse.append(wfs_arrays) - all_waveforms = extract_waveforms_to_unique_buffer( recording, spikes, @@ -123,8 +126,10 @@ def test_waveform_tools(): **mode_kwargs_, **job_kwargs, ) - wfs_arrays = split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs['sparsity_mask']) - if sparsity_kwargs['sparsity_mask'] is None: + wfs_arrays = split_waveforms_by_units( + unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs["sparsity_mask"] + ) + if sparsity_kwargs["sparsity_mask"] is None: list_wfs_dense.append(wfs_arrays) else: list_wfs_sparse.append(wfs_arrays) @@ -133,8 +138,5 @@ def test_waveform_tools(): _check_all_wf_equal(list_wfs_sparse) - - - if __name__ == "__main__": test_waveform_tools() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 1a2361ec97..e6f7e944cc 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -274,7 +274,7 @@ def distribute_waveforms_to_buffers( sparsity_mask, ) if job_name is None: - job_name=f"extract waveforms {mode} multi buffer" + job_name = f"extract waveforms {mode} multi buffer" processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() @@ -413,7 +413,6 @@ def extract_waveforms_to_unique_buffer( job_name=None, **job_kwargs, ): - nsamples = nbefore + nafter dtype = np.dtype(dtype) @@ -445,7 +444,6 @@ def 
extract_waveforms_to_unique_buffer( else: raise ValueError("allocate_waveforms_buffers bad mode") - job_kwargs = fix_job_kwargs(job_kwargs) inds_by_unit = {} @@ -468,7 +466,6 @@ def extract_waveforms_to_unique_buffer( return_scaled, mode, sparsity_mask, - ) if job_name is None: job_name = f"extract waveforms {mode} mono buffer" @@ -476,7 +473,6 @@ def extract_waveforms_to_unique_buffer( processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() - if mode == "memmap": return all_waveforms elif mode == "shared_memory": @@ -490,12 +486,9 @@ def extract_waveforms_to_unique_buffer( return all_waveforms, wf_array_info - - def _init_worker_ditribute_one_buffer( recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): - worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["wf_array_info"] = wf_array_info @@ -513,6 +506,7 @@ def _init_worker_ditribute_one_buffer( worker_ctx["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory + shm, shm_name, dtype, shape = wf_array_info shm = SharedMemory(shm_name) all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) @@ -546,7 +540,7 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c seg_size = recording.get_num_samples(segment_index=segment_index) s0, s1 = segment_slices[segment_index] - in_seg_spikes = spikes[s0: s1] + in_seg_spikes = spikes[s0:s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! @@ -584,7 +578,7 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c else: mask = sparsity_mask[unit_index, :] wf = wf[:, mask] - all_waveforms[spike_ind, :, :wf.shape[1]] = wf + all_waveforms[spike_ind, :, : wf.shape[1]] = wf def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None): From 45036bf72c84376edeebb531f4feebeb8f0255e2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 10 Jul 2023 14:30:34 +0200 Subject: [PATCH 06/14] wip waveforms tools single buffer --- .../core/tests/test_waveform_tools.py | 4 +- src/spikeinterface/core/waveform_tools.py | 173 +++++++++++++----- 2 files changed, 131 insertions(+), 46 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 9a51a10ee2..fb65e87458 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -8,7 +8,7 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( extract_waveforms_to_buffers, - extract_waveforms_to_unique_buffer, + extract_waveforms_to_single_buffer, split_waveforms_by_units, ) @@ -113,7 +113,7 @@ def test_waveform_tools(): else: list_wfs_sparse.append(wfs_arrays) - all_waveforms = extract_waveforms_to_unique_buffer( + all_waveforms = extract_waveforms_to_single_buffer( recording, spikes, unit_ids, diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index e6f7e944cc..252ea68738 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -36,7 +36,7 @@ def extract_waveforms_to_buffers( Same as calling allocate_waveforms_buffers() and then distribute_waveforms_to_buffers(). 
- Important note: for the "shared_memory" mode wfs_arrays_info contains reference to + Important note: for the "shared_memory" mode arrays_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. And this variable is also returned. To avoid this a copy to non shared memmory can be perform at the end. @@ -66,17 +66,17 @@ def extract_waveforms_to_buffers( If not None shape must be must be (len(unit_ids), len(channel_ids)) copy: bool If True (default), the output shared memory object is copied to a numpy standard array. - If copy=False then wfs_arrays_info is also return. Please keep in mind that wfs_arrays_info - need to be referenced as long as wfs_arrays will be used otherwise it will be very hard to debug. + If copy=False then arrays_info is also return. Please keep in mind that arrays_info + need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. Also when copy=False the SharedMemory will need to be unlink manually {} Returns ------- - wfs_arrays: dict of arrays + waveforms_by_units: dict of arrays Arrays for all units (memmap or shared_memmep) - wfs_arrays_info: dict of info + arrays_info: dict of info Optionally return in case of shared_memory if copy=False. Dictionary to "construct" array in workers process (memmap file or sharemem info) """ @@ -89,7 +89,7 @@ def extract_waveforms_to_buffers( dtype = "float32" dtype = np.dtype(dtype) - wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers( + waveforms_by_units, arrays_info = allocate_waveforms_buffers( recording, spikes, unit_ids, nbefore, nafter, mode=mode, folder=folder, dtype=dtype, sparsity_mask=sparsity_mask ) @@ -97,7 +97,7 @@ def extract_waveforms_to_buffers( recording, spikes, unit_ids, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -107,19 +107,19 @@ def extract_waveforms_to_buffers( ) if mode == "memmap": - return wfs_arrays + return waveforms_by_units elif mode == "shared_memory": if copy: - wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} + waveforms_by_units = {unit_id: arr.copy() for unit_id, arr in waveforms_by_units.items()} # release all sharedmem buffer for unit_id in unit_ids: - shm = wfs_arrays_info[unit_id][0] + shm = arrays_info[unit_id][0] if shm is not None: # empty array have None shm.unlink() - return wfs_arrays + return waveforms_by_units else: - return wfs_arrays, wfs_arrays_info + return waveforms_by_units, arrays_info extract_waveforms_to_buffers.__doc__ = extract_waveforms_to_buffers.__doc__.format(_shared_job_kwargs_doc) @@ -131,7 +131,7 @@ def allocate_waveforms_buffers( """ Allocate memmap or shared memory buffers before snippet extraction. - Important note: for the shared memory mode wfs_arrays_info contains reference to + Important note: for the shared memory mode arrays_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. 
Parameters @@ -158,9 +158,9 @@ def allocate_waveforms_buffers( Returns ------- - wfs_arrays: dict of arrays + waveforms_by_units: dict of arrays Arrays for all units (memmap or shared_memmep - wfs_arrays_info: dict of info + arrays_info: dict of info Dictionary to "construct" array in workers process (memmap file or sharemem) """ @@ -173,8 +173,8 @@ def allocate_waveforms_buffers( folder = Path(folder) # prepare buffers - wfs_arrays = {} - wfs_arrays_info = {} + waveforms_by_units = {} + arrays_info = {} for unit_ind, unit_id in enumerate(unit_ids): n_spikes = np.sum(spikes["unit_index"] == unit_ind) if sparsity_mask is None: @@ -186,8 +186,8 @@ def allocate_waveforms_buffers( if mode == "memmap": filename = str(folder / f"waveforms_{unit_id}.npy") arr = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) - wfs_arrays[unit_id] = arr - wfs_arrays_info[unit_id] = filename + waveforms_by_units[unit_id] = arr + arrays_info[unit_id] = filename elif mode == "shared_memory": if n_spikes == 0 or num_chans == 0: arr = np.zeros(shape, dtype=dtype) @@ -196,19 +196,19 @@ def allocate_waveforms_buffers( else: arr, shm = make_shared_array(shape, dtype) shm_name = shm.name - wfs_arrays[unit_id] = arr - wfs_arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) + waveforms_by_units[unit_id] = arr + arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) else: raise ValueError("allocate_waveforms_buffers bad mode") - return wfs_arrays, wfs_arrays_info + return waveforms_by_units, arrays_info def distribute_waveforms_to_buffers( recording, spikes, unit_ids, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -222,7 +222,7 @@ def distribute_waveforms_to_buffers( Buffers must be pre-allocated with the `allocate_waveforms_buffers()` function. - Important note, for "shared_memory" mode wfs_arrays_info contain reference to + Important note, for "shared_memory" mode arrays_info contain reference to the shared memmory buffer, this variable must be reference as long as arrays as used. 
Parameters @@ -234,7 +234,7 @@ def distribute_waveforms_to_buffers( This vector can be spikes = Sorting.to_spike_vector() unit_ids: list ot numpy List of unit_ids - wfs_arrays_info: dict + arrays_info: dict Dictionary to "construct" array in workers process (memmap file or sharemem) nbefore: int N samples before spike @@ -265,7 +265,7 @@ def distribute_waveforms_to_buffers( recording, unit_ids, spikes, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -284,7 +284,7 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor def _init_worker_ditribute_buffers( - recording, unit_ids, spikes, wfs_arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask + recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker worker_ctx = {} @@ -297,23 +297,23 @@ def _init_worker_ditribute_buffers( if mode == "memmap": # in memmap mode we have the "too many open file" problem with linux # memmap file will be open on demand and not globally per worker - worker_ctx["wfs_arrays_info"] = wfs_arrays_info + worker_ctx["arrays_info"] = arrays_info elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - wfs_arrays = {} + waveforms_by_units = {} shms = {} - for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): + for unit_id, (shm, shm_name, dtype, shape) in arrays_info.items(): if shm_name is None: arr = np.zeros(shape=shape, dtype=dtype) else: shm = SharedMemory(shm_name) arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - wfs_arrays[unit_id] = arr + waveforms_by_units[unit_id] = arr # we need a reference to all sham otherwise we get segment fault!!! shms[unit_id] = shm worker_ctx["shms"] = shms - worker_ctx["wfs_arrays"] = wfs_arrays + worker_ctx["waveforms_by_units"] = waveforms_by_units worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes @@ -383,10 +383,10 @@ def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx) if worker_ctx["mode"] == "memmap": # open file in demand (and also autoclose it after) - filename = worker_ctx["wfs_arrays_info"][unit_id] + filename = worker_ctx["arrays_info"][unit_id] wfs = np.load(str(filename), mmap_mode="r+") elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["wfs_arrays"][unit_id] + wfs = worker_ctx["waveforms_by_units"][unit_id] for pos in in_chunk_pos: sample_index = spikes[inds[pos]]["sample_index"] @@ -398,7 +398,7 @@ def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx) wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] -def extract_waveforms_to_unique_buffer( +def extract_waveforms_to_single_buffer( recording, spikes, unit_ids, @@ -413,6 +413,57 @@ def extract_waveforms_to_unique_buffer( job_name=None, **job_kwargs, ): + """ + Allocate a single buffer (memmap or or shared memory) and then distribute every waveform into it. + + Contrary to extract_waveforms_to_buffers() all waveforms are extracted in the same buffer, so the spike vector is + needed to recover waveforms unit by unit. Importantly in case of sparsity, the channel are not aligned across + units. + + Important note: for the "shared_memory" mode wf_array_info contains reference to + the shared memmory buffer, this variable must be reference as long as arrays as used. + And this variable is also returned. + To avoid this a copy to non shared memmory can be perform at the end. 
+ + Parameters + ---------- + recording: recording + The recording object + spikes: 1d numpy array with several fields + Spikes handled as a unique vector. + This vector can be obtained with: `spikes = Sorting.to_spike_vector()` + unit_ids: list ot numpy + List of unit_ids + nbefore: int + N samples before spike + nafter: int + N samples after spike + mode: str + Mode to use ('memmap' | 'shared_memory') + return_scaled: bool + Scale traces before exporting to buffer or not. + folder: str or path + In case of memmap mode, folder to save npy files + dtype: numpy.dtype + dtype for waveforms buffer + sparsity_mask: None or array of bool + If not None shape must be must be (len(unit_ids), len(channel_ids)) + copy: bool + If True (default), the output shared memory object is copied to a numpy standard array. + If copy=False then arrays_info is also return. Please keep in mind that arrays_info + need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. + Also when copy=False the SharedMemory will need to be unlink manually + {} + + Returns + ------- + all_waveforms: numpy array + Single array with shape (nump_spikes, num_samples, num_channels) + + wf_array_info: dict of info + Optionally return in case of shared_memory if copy=False. + Dictionary to "construct" array in workers process (memmap file or sharemem info) + """ nsamples = nbefore + nafter dtype = np.dtype(dtype) @@ -453,8 +504,8 @@ def extract_waveforms_to_unique_buffer( if num_spikes > 0: # and run - func = _worker_ditribute_one_buffer - init_func = _init_worker_ditribute_one_buffer + func = _worker_ditribute_single_buffer + init_func = _init_worker_ditribute_single_buffer init_args = ( recording, @@ -486,7 +537,7 @@ def extract_waveforms_to_unique_buffer( return all_waveforms, wf_array_info -def _init_worker_ditribute_one_buffer( +def _init_worker_ditribute_single_buffer( recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): worker_ctx = {} @@ -525,7 +576,7 @@ def _init_worker_ditribute_one_buffer( # used by ChunkRecordingExecutor -def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_ctx): +def _worker_ditribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] @@ -580,19 +631,53 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c wf = wf[:, mask] all_waveforms[spike_ind, :, : wf.shape[1]] = wf + if worker_ctx["mode"] == "memmap": + all_waveforms.flush() + + +def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None, folder=None): + """ + Split a single buffer waveforms into waveforms by units (multi buffers or multi files). + + Parameters + ---------- + unit_ids: list or numpy array + List of unit ids + spikes: numpy array + The spike vector + all_waveforms : numpy array + Single buffer containing all waveforms + sparsity_mask : None or numpy array + Optionally the boolean sparsity mask + folder : None or str or Path + If a folde ri sgiven all -def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None): - waveform_by_units = {} + Returns + ------- + waveforms_by_units: dict of array + A dict of arrays. + In case of folder not None, this contain the memmap of the files. 
+ """ + if folder is not None: + folder = Path(folder) + waveforms_by_units = {} for unit_index, unit_id in enumerate(unit_ids): mask = spikes["unit_index"] == unit_index if sparsity_mask is not None: chan_mask = sparsity_mask[unit_index, :] num_chans = np.sum(chan_mask) - waveform_by_units[unit_id] = all_waveforms[mask, :, :][:, :, :num_chans] + wfs = all_waveforms[mask, :, :][:, :, :num_chans] + else: + wfs = all_waveforms[mask, :, :] + + if folder is None: + waveforms_by_units[unit_id] = wfs else: - waveform_by_units[unit_id] = all_waveforms[mask, :, :] + np.save(folder / f"waveforms_{unit_id}.npy", wfs) + # this avoid keeping in memory all waveforms + waveforms_by_units[unit_id] = np.load(f"waveforms_{unit_id}.npy", mmap_mode="r") - return waveform_by_units + return waveforms_by_units def has_exceeding_spikes(recording, sorting): From 570a3a5fa8dc1b340da4ee0fc6baa16013213bdb Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 26 Jul 2023 17:57:40 +0200 Subject: [PATCH 07/14] fix local tests --- src/spikeinterface/core/tests/test_waveform_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index fb65e87458..52d7472c92 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -86,7 +86,7 @@ def test_waveform_tools(): wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" if wf_folder.is_dir(): shutil.rmtree(wf_folder) - wf_folder.mkdir() + wf_folder.mkdir(parents=True) mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) else: mode_kwargs_ = mode_kwargs From 34538311d1c00f5f45b4cb84a03984efa9ef4f3d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 7 Sep 2023 18:38:25 +0200 Subject: [PATCH 08/14] fedeback from alessio and ramon --- src/spikeinterface/core/waveform_tools.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 252ea68738..53c0df68df 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -258,8 +258,8 @@ def distribute_waveforms_to_buffers( inds_by_unit[unit_id] = inds # and run - func = _worker_ditribute_buffers - init_func = _init_worker_ditribute_buffers + func = _worker_distribute_buffers + init_func = _init_worker_distribute_buffers init_args = ( recording, @@ -283,7 +283,7 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor -def _init_worker_ditribute_buffers( +def _init_worker_distribute_buffers( recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker @@ -329,7 +329,7 @@ def _init_worker_ditribute_buffers( # used by ChunkRecordingExecutor -def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] @@ -480,7 +480,7 @@ def extract_waveforms_to_single_buffer( shape = (num_spikes, nsamples, num_chans) if mode == "memmap": - filename = str(folder / f"all_waveforms.npy") + filename = str(folder / f"waveforms.npy") all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) wf_array_info = filename elif mode == "shared_memory": @@ -497,15 
+497,10 @@ def extract_waveforms_to_single_buffer( job_kwargs = fix_job_kwargs(job_kwargs) - inds_by_unit = {} - for unit_ind, unit_id in enumerate(unit_ids): - (inds,) = np.nonzero(spikes["unit_index"] == unit_ind) - inds_by_unit[unit_id] = inds - if num_spikes > 0: # and run - func = _worker_ditribute_single_buffer - init_func = _init_worker_ditribute_single_buffer + func = _worker_distribute_single_buffer + init_func = _init_worker_distribute_single_buffer init_args = ( recording, @@ -537,7 +532,7 @@ def extract_waveforms_to_single_buffer( return all_waveforms, wf_array_info -def _init_worker_ditribute_single_buffer( +def _init_worker_distribute_single_buffer( recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): worker_ctx = {} @@ -576,7 +571,7 @@ def _init_worker_ditribute_single_buffer( # used by ChunkRecordingExecutor -def _worker_ditribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] From d79dbe26bb1d3f5db45b1ac36d84f7b96be08f18 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 7 Sep 2023 18:51:26 +0200 Subject: [PATCH 09/14] extract_waveforms_to_single_buffer change folder to file_path --- .../core/tests/test_waveform_tools.py | 27 ++++++++++++------- src/spikeinterface/core/waveform_tools.py | 17 +++++------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 52d7472c92..1d7e38832a 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -59,14 +59,15 @@ def test_waveform_tools(): ] some_modes = [ {"mode": "memmap"}, + {"mode": "shared_memory"}, ] - if platform.system() != "Windows": - # shared memory on windows is buggy... - some_modes.append( - { - "mode": "shared_memory", - } - ) + # if platform.system() != "Windows": + # # shared memory on windows is buggy... 
+ # some_modes.append( + # { + # "mode": "shared_memory", + # } + # ) some_sparsity = [ dict(sparsity_mask=None), @@ -87,9 +88,11 @@ def test_waveform_tools(): if wf_folder.is_dir(): shutil.rmtree(wf_folder) wf_folder.mkdir(parents=True) - mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) - else: - mode_kwargs_ = mode_kwargs + wf_file_path = wf_folder / "waveforms_all_units.npy" + + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["folder" ] = wf_folder wfs_arrays = extract_waveforms_to_buffers( recording, @@ -113,6 +116,10 @@ def test_waveform_tools(): else: list_wfs_sparse.append(wfs_arrays) + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["file_path" ] = wf_file_path + all_waveforms = extract_waveforms_to_single_buffer( recording, spikes, diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 53c0df68df..c363ac49dc 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -406,7 +406,7 @@ def extract_waveforms_to_single_buffer( nafter, mode="memmap", return_scaled=False, - folder=None, + file_path=None, dtype=None, sparsity_mask=None, copy=False, @@ -442,8 +442,8 @@ def extract_waveforms_to_single_buffer( Mode to use ('memmap' | 'shared_memory') return_scaled: bool Scale traces before exporting to buffer or not. - folder: str or path - In case of memmap mode, folder to save npy files + file_path: str or path + In case of memmap mode, file to save npy file. dtype: numpy.dtype dtype for waveforms buffer sparsity_mask: None or array of bool @@ -468,9 +468,9 @@ def extract_waveforms_to_single_buffer( dtype = np.dtype(dtype) if mode == "shared_memory": - assert folder is None + assert file_path is None else: - folder = Path(folder) + file_path = Path(file_path) num_spikes = spikes.size if sparsity_mask is None: @@ -480,9 +480,8 @@ def extract_waveforms_to_single_buffer( shape = (num_spikes, nsamples, num_chans) if mode == "memmap": - filename = str(folder / f"waveforms.npy") - all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) - wf_array_info = filename + all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) + wf_array_info = str(file_path) elif mode == "shared_memory": if num_spikes == 0 or num_chans == 0: all_waveforms = np.zeros(shape, dtype=dtype) @@ -538,7 +537,6 @@ def _init_worker_distribute_single_buffer( worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["wf_array_info"] = wf_array_info - worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes worker_ctx["nbefore"] = nbefore worker_ctx["nafter"] = nafter @@ -574,7 +572,6 @@ def _init_worker_distribute_single_buffer( def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] - unit_ids = worker_ctx["unit_ids"] segment_slices = worker_ctx["segment_slices"] spikes = worker_ctx["spikes"] nbefore = worker_ctx["nbefore"] From 5c6615975704bd9bbbda722291d04ff0ccfb3c90 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:51:54 +0000 Subject: [PATCH 10/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_waveform_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
index 1d7e38832a..e9cf1bfb5f 100644
--- a/src/spikeinterface/core/tests/test_waveform_tools.py
+++ b/src/spikeinterface/core/tests/test_waveform_tools.py
@@ -92,7 +92,7 @@ def test_waveform_tools():

             mode_kwargs_ = dict(**mode_kwargs)
             if mode_kwargs["mode"] == "memmap":
-                mode_kwargs_["folder" ] = wf_folder
+                mode_kwargs_["folder"] = wf_folder

             wfs_arrays = extract_waveforms_to_buffers(
                 recording,
@@ -118,8 +118,8 @@ def test_waveform_tools():

             mode_kwargs_ = dict(**mode_kwargs)
             if mode_kwargs["mode"] == "memmap":
-                mode_kwargs_["file_path" ] = wf_file_path
-
+                mode_kwargs_["file_path"] = wf_file_path
+
             all_waveforms = extract_waveforms_to_single_buffer(
                 recording,
                 spikes,

From 7080696f12617bfd08769c89a8471768878191ca Mon Sep 17 00:00:00 2001
From: Garcia Samuel
Date: Fri, 8 Sep 2023 10:17:48 +0200
Subject: [PATCH 11/14] Update src/spikeinterface/core/waveform_tools.py

Co-authored-by: Alessio Buccino
---
 src/spikeinterface/core/waveform_tools.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index c363ac49dc..a63d0a80b7 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -417,7 +417,7 @@ def extract_waveforms_to_single_buffer(

     Allocate a single buffer (memmap or or shared memory) and then distribute every waveform into it.
     Contrary to extract_waveforms_to_buffers() all waveforms are extracted in the same buffer, so the spike vector is
-    needed to recover waveforms unit by unit. Importantly in case of sparsity, the channel are not aligned across
+    needed to recover waveforms unit by unit. Importantly in case of sparsity, the channels are not aligned across
     units.

     Important note: for the "shared_memory" mode wf_array_info contains reference to

From e5a523c9263fa1a229e89905639496da03dd39e0 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Tue, 12 Sep 2023 10:12:54 +0200
Subject: [PATCH 12/14] Improvement after Ramon comments.

---
 src/spikeinterface/core/waveform_tools.py | 67 +++++++++++------------
 1 file changed, 31 insertions(+), 36 deletions(-)

diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index a63d0a80b7..6e0d6f412b 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -350,16 +350,10 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx
     # take only spikes in range [start_frame, end_frame]
     # this is a slice so no copy!!
-    i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame)
-    i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame)
-    if i0 != i1:
-        # protect from spikes on border : spike_time<0 or spike_time>seg_size
-        # useful only when max_spikes_per_unit is not None
-        # waveform will not be extracted and a zeros will be left in the memmap file
-        while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1):
-            i0 = i0 + 1
-        while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1):
-            i1 = i1 - 1
+    # the borders of the segment are protected by nbefore on the left and nafter on the right
+    i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
+    i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))
+

     # slice in absolut in spikes vector
     l0 = i0 + s0
     l1 = i1 + s0
@@ -420,6 +414,9 @@ def extract_waveforms_to_single_buffer(
     needed to recover waveforms unit by unit. Importantly in case of sparsity, the channels are not aligned across
     units.

+    Note: spikes near borders (nbefore/nafter) are not extracted and zeros are left in the output buffer.
+    This ensures that spikes.shape[0] == all_waveforms.shape[0].
+
     Important note: for the "shared_memory" mode wf_array_info contains reference to
     the shared memmory buffer, this variable must be reference as long as arrays as used.
     And this variable is also returned.
@@ -449,10 +446,13 @@ def extract_waveforms_to_single_buffer(
     sparsity_mask: None or array of bool
         If not None shape must be must be (len(unit_ids), len(channel_ids))
     copy: bool
-        If True (default), the output shared memory object is copied to a numpy standard array.
-        If copy=False then arrays_info is also return. Please keep in mind that arrays_info
-        need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug.
-        Also when copy=False the SharedMemory will need to be unlink manually
+        If True (default), the output shared memory object is copied to a numpy standard array and no reference
+        to the internal shared memory object is kept.
+        If copy=False then the shared memory object is also returned. Please keep in mind that the shared memory object
+        needs to be referenced as long as all_waveforms is used, otherwise it might produce segmentation
+        faults which are hard to debug.
+        Also when copy=False the SharedMemory will need to be unlinked manually if proper cleanup of resources is desired.
+
     {}
     Returns
     -------
@@ -481,7 +481,8 @@ def extract_waveforms_to_single_buffer(

     if mode == "memmap":
         all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape)
-        wf_array_info = str(file_path)
+        # wf_array_info = str(file_path)
+        wf_array_info = dict(filename=str(file_path))
     elif mode == "shared_memory":
         if num_spikes == 0 or num_chans == 0:
             all_waveforms = np.zeros(shape, dtype=dtype)
@@ -490,7 +491,8 @@ def extract_waveforms_to_single_buffer(
         else:
             all_waveforms, shm = make_shared_array(shape, dtype)
             shm_name = shm.name
-            wf_array_info = (shm, shm_name, dtype.str, shape)
+            # wf_array_info = (shm, shm_name, dtype.str, shape)
+            wf_array_info = dict(shm=shm, shm_name=shm_name, dtype=dtype.str, shape=shape)
     else:
         raise ValueError("allocate_waveforms_buffers bad mode")

@@ -503,7 +505,6 @@ def extract_waveforms_to_single_buffer(

         init_args = (
             recording,
-            unit_ids,
             spikes,
             wf_array_info,
             nbefore,
@@ -532,7 +533,7 @@ def extract_waveforms_to_single_buffer(


 def _init_worker_distribute_single_buffer(
-    recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask
+    recording, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask
 ):
     worker_ctx = {}
     worker_ctx["recording"] = recording
@@ -545,13 +546,13 @@ def _init_worker_distribute_single_buffer(
     worker_ctx["mode"] = mode

     if mode == "memmap":
-        filename = wf_array_info
+        filename = wf_array_info["filename"]
         all_waveforms = np.load(str(filename), mmap_mode="r+")
         worker_ctx["all_waveforms"] = all_waveforms
     elif mode == "shared_memory":
         from multiprocessing.shared_memory import SharedMemory

-        shm, shm_name, dtype, shape = wf_array_info
+        shm_name, dtype, shape = wf_array_info["shm_name"], wf_array_info["dtype"], wf_array_info["shape"]
         shm = SharedMemory(shm_name)
         all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
         worker_ctx["shm"] = shm
@@ -587,16 +588,10 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work
     # take only spikes in range [start_frame, end_frame]
     # this is a slice so no copy!!
-    i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame)
-    i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame)
-    if i0 != i1:
-        # protect from spikes on border : spike_time<0 or spike_time>seg_size
-        # useful only when max_spikes_per_unit is not None
-        # waveform will not be extracted and a zeros will be left in the memmap file
-        while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1):
-            i0 = i0 + 1
-        while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1):
-            i1 = i1 - 1
+    # the borders of the segment are protected by nbefore on the left and nafter on the right
+    i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
+    i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))
+

     # slice in absolut in spikes vector
     l0 = i0 + s0
     l1 = i1 + s0
@@ -611,17 +606,17 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work
             start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled
         )

-        for spike_ind in range(l0, l1):
-            sample_index = spikes[spike_ind]["sample_index"]
-            unit_index = spikes[spike_ind]["unit_index"]
+        for spike_index in range(l0, l1):
+            sample_index = spikes[spike_index]["sample_index"]
+            unit_index = spikes[spike_index]["unit_index"]
             wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :]

             if sparsity_mask is None:
-                all_waveforms[spike_ind, :, :] = wf
+                all_waveforms[spike_index, :, :] = wf
             else:
                 mask = sparsity_mask[unit_index, :]
                 wf = wf[:, mask]
-                all_waveforms[spike_ind, :, : wf.shape[1]] = wf
+                all_waveforms[spike_index, :, : wf.shape[1]] = wf

     if worker_ctx["mode"] == "memmap":
         all_waveforms.flush()
@@ -642,7 +637,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None
     sparsity_mask : None or numpy array
         Optionally the boolean sparsity mask
     folder : None or str or Path
-        If a folde ri sgiven all
+        If a folder is given all waveforms by units are copied in a npy file using f"waveforms_{unit_id}.npy" naming.

     Returns
     -------

From e37b0515742b129984eb75da35c869a1de6b78d3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 12 Sep 2023 08:13:32 +0000
Subject: [PATCH 13/14] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/waveform_tools.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index 6e0d6f412b..39623da329 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -354,7 +354,6 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx
     i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
     i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))
-

     # slice in absolut in spikes vector
     l0 = i0 + s0
     l1 = i1 + s0
@@ -416,7 +415,7 @@ def extract_waveforms_to_single_buffer(

     Note: spikes near borders (nbefore/nafter) are not extracted and zeros are left in the output buffer.
     This ensures that spikes.shape[0] == all_waveforms.shape[0].
-
+
     Important note: for the "shared_memory" mode wf_array_info contains reference to
     the shared memmory buffer, this variable must be reference as long as arrays as used.
     And this variable is also returned.
@@ -481,7 +480,7 @@ def extract_waveforms_to_single_buffer(

     if mode == "memmap":
         all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape)
-        # wf_array_info = str(file_path)
+        # wf_array_info = str(file_path)
         wf_array_info = dict(filename=str(file_path))
     elif mode == "shared_memory":
         if num_spikes == 0 or num_chans == 0:
@@ -592,7 +591,6 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work
     i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
     i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))
-

     # slice in absolut in spikes vector
     l0 = i0 + s0
     l1 = i1 + s0

From 966d56a9fa150472e43f888ad4e6b62f89a77ef6 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 13 Sep 2023 08:18:51 +0200
Subject: [PATCH 14/14] doc

---
 src/spikeinterface/core/waveform_tools.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index 39623da329..da8e3d64b6 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -417,9 +417,11 @@ def extract_waveforms_to_single_buffer(
     This ensures that spikes.shape[0] == all_waveforms.shape[0].

     Important note: for the "shared_memory" mode wf_array_info contains reference to
-    the shared memmory buffer, this variable must be reference as long as arrays as used.
-    And this variable is also returned.
-    To avoid this a copy to non shared memmory can be perform at the end.
+    the shared memory buffer; this variable must be referenced as long as the array is used.
+    unlink() must also be called on the shared memory when the array is de-referenced.
+    To avoid this complicated behavior, by default (copy=True) the shared memory buffer is copied into a standard
+    numpy array.
+

     Parameters
     ----------
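Usage sketch (appended for reference, not part of the commits above): the snippet below illustrates how the single-buffer API could be called once this series is applied. It is a minimal sketch under assumptions: the pre-existing `recording` and `sorting` objects, `sorting.to_spike_vector()` and `sorting.unit_ids`, the `nbefore`/`nafter` values, the output file name, and the job kwargs (`n_jobs`, `chunk_duration`) are illustrative and not taken from the diffs; the two-value return follows the `return all_waveforms, wf_array_info` statement visible in the series.

    from pathlib import Path

    from spikeinterface.core.waveform_tools import (
        extract_waveforms_to_single_buffer,
        split_waveforms_by_units,
    )

    # assumed to already exist: `recording` (a BaseRecording) and `sorting` (a BaseSorting)
    spikes = sorting.to_spike_vector()
    unit_ids = sorting.unit_ids
    nbefore, nafter = 20, 44  # samples kept before/after each spike (illustrative values)

    # "memmap" mode: every waveform is written into a single .npy file on disk
    all_waveforms, wf_array_info = extract_waveforms_to_single_buffer(
        recording,
        spikes,
        unit_ids,
        nbefore,
        nafter,
        mode="memmap",
        file_path=Path("waveforms_all_units.npy"),
        dtype="float32",
        n_jobs=1,
        chunk_duration="1s",
    )
    # all_waveforms has shape (spikes.size, nbefore + nafter, num_channels)

    # "shared_memory" mode with copy=True: the buffer is copied into a regular numpy
    # array, so no reference to the SharedMemory object has to be kept by the caller
    all_waveforms, wf_array_info = extract_waveforms_to_single_buffer(
        recording, spikes, unit_ids, nbefore, nafter,
        mode="shared_memory", file_path=None, dtype="float32", copy=True,
        n_jobs=1, chunk_duration="1s",
    )

    # regroup the flat buffer by unit when per-unit waveforms are needed
    wfs_by_unit = split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None)

When a `sparsity_mask` is passed, it must be a boolean array of shape (num_units, num_channels) and, as the docstring above notes, channels are then not aligned across units inside the flat buffer.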