From e5a523c9263fa1a229e89905639496da03dd39e0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 12 Sep 2023 10:12:54 +0200 Subject: [PATCH] Improvement after Ramon comments. --- src/spikeinterface/core/waveform_tools.py | 67 +++++++++++------------ 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a63d0a80b7..6e0d6f412b 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -350,16 +350,10 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! - i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) - i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) - if i0 != i1: - # protect from spikes on border : spike_time<0 or spike_time>seg_size - # useful only when max_spikes_per_unit is not None - # waveform will not be extracted and a zeros will be left in the memmap file - while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): - i0 = i0 + 1 - while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): - i1 = i1 - 1 + # the border of segment are protected by nbefore on left an nafter on the right + i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) + i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + # slice in absolut in spikes vector l0 = i0 + s0 @@ -420,6 +414,9 @@ def extract_waveforms_to_single_buffer( needed to recover waveforms unit by unit. Importantly in case of sparsity, the channels are not aligned across units. + Note: spikes near borders (nbefore/nafter) are not extracted and 0 are put the output buffer. + This ensures that spikes.shape[0] == all_waveforms.shape[0]. + Important note: for the "shared_memory" mode wf_array_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. And this variable is also returned. @@ -449,10 +446,13 @@ def extract_waveforms_to_single_buffer( sparsity_mask: None or array of bool If not None shape must be must be (len(unit_ids), len(channel_ids)) copy: bool - If True (default), the output shared memory object is copied to a numpy standard array. - If copy=False then arrays_info is also return. Please keep in mind that arrays_info - need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. - Also when copy=False the SharedMemory will need to be unlink manually + If True (default), the output shared memory object is copied to a numpy standard array and no reference + to the internal shared memory object is kept. + If copy=False then the shared memory object is also returned. Please keep in mind that the shared memory object + need to be referenced as long as all_waveforms will be used otherwise it might produce segmentation + faults which are hard to debug. + Also when copy=False the SharedMemory will need to be unlink manually if proper cleanup of resources is desired. + {} Returns @@ -481,7 +481,8 @@ def extract_waveforms_to_single_buffer( if mode == "memmap": all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) - wf_array_info = str(file_path) + # wf_array_info = str(file_path) + wf_array_info = dict(filename=str(file_path)) elif mode == "shared_memory": if num_spikes == 0 or num_chans == 0: all_waveforms = np.zeros(shape, dtype=dtype) @@ -490,7 +491,8 @@ def extract_waveforms_to_single_buffer( else: all_waveforms, shm = make_shared_array(shape, dtype) shm_name = shm.name - wf_array_info = (shm, shm_name, dtype.str, shape) + # wf_array_info = (shm, shm_name, dtype.str, shape) + wf_array_info = dict(shm=shm, shm_name=shm_name, dtype=dtype.str, shape=shape) else: raise ValueError("allocate_waveforms_buffers bad mode") @@ -503,7 +505,6 @@ def extract_waveforms_to_single_buffer( init_args = ( recording, - unit_ids, spikes, wf_array_info, nbefore, @@ -532,7 +533,7 @@ def extract_waveforms_to_single_buffer( def _init_worker_distribute_single_buffer( - recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask + recording, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): worker_ctx = {} worker_ctx["recording"] = recording @@ -545,13 +546,13 @@ def _init_worker_distribute_single_buffer( worker_ctx["mode"] = mode if mode == "memmap": - filename = wf_array_info + filename = wf_array_info["filename"] all_waveforms = np.load(str(filename), mmap_mode="r+") worker_ctx["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - shm, shm_name, dtype, shape = wf_array_info + shm_name, dtype, shape = wf_array_info["shm_name"], wf_array_info["dtype"], wf_array_info["shape"] shm = SharedMemory(shm_name) all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) worker_ctx["shm"] = shm @@ -587,16 +588,10 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! - i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) - i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) - if i0 != i1: - # protect from spikes on border : spike_time<0 or spike_time>seg_size - # useful only when max_spikes_per_unit is not None - # waveform will not be extracted and a zeros will be left in the memmap file - while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): - i0 = i0 + 1 - while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): - i1 = i1 - 1 + # the border of segment are protected by nbefore on left an nafter on the right + i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) + i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + # slice in absolut in spikes vector l0 = i0 + s0 @@ -611,17 +606,17 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled ) - for spike_ind in range(l0, l1): - sample_index = spikes[spike_ind]["sample_index"] - unit_index = spikes[spike_ind]["unit_index"] + for spike_index in range(l0, l1): + sample_index = spikes[spike_index]["sample_index"] + unit_index = spikes[spike_index]["unit_index"] wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] if sparsity_mask is None: - all_waveforms[spike_ind, :, :] = wf + all_waveforms[spike_index, :, :] = wf else: mask = sparsity_mask[unit_index, :] wf = wf[:, mask] - all_waveforms[spike_ind, :, : wf.shape[1]] = wf + all_waveforms[spike_index, :, : wf.shape[1]] = wf if worker_ctx["mode"] == "memmap": all_waveforms.flush() @@ -642,7 +637,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None sparsity_mask : None or numpy array Optionally the boolean sparsity mask folder : None or str or Path - If a folde ri sgiven all + If a folder is given all waveforms by units are copied in a npy file using f"waveforms_{unit_id}.npy" naming. Returns -------