Skip to content

Commit

Permalink
Improvement after Ramon comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Sep 12, 2023
1 parent 7080696 commit e5a523c
Showing 1 changed file with 31 additions and 36 deletions.
67 changes: 31 additions & 36 deletions src/spikeinterface/core/waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,16 +350,10 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx

# take only spikes in range [start_frame, end_frame]
# this is a slice so no copy!!
i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame)
i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame)
if i0 != i1:
# protect from spikes on border : spike_time<0 or spike_time>seg_size
# useful only when max_spikes_per_unit is not None
# waveform will not be extracted and a zeros will be left in the memmap file
while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1):
i0 = i0 + 1
while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1):
i1 = i1 - 1
# the border of segment are protected by nbefore on left an nafter on the right
i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))


# slice in absolut in spikes vector
l0 = i0 + s0
Expand Down Expand Up @@ -420,6 +414,9 @@ def extract_waveforms_to_single_buffer(
needed to recover waveforms unit by unit. Importantly in case of sparsity, the channels are not aligned across
units.
Note: spikes near borders (nbefore/nafter) are not extracted and 0 are put the output buffer.
This ensures that spikes.shape[0] == all_waveforms.shape[0].
Important note: for the "shared_memory" mode wf_array_info contains reference to
the shared memmory buffer, this variable must be reference as long as arrays as used.
And this variable is also returned.
Expand Down Expand Up @@ -449,10 +446,13 @@ def extract_waveforms_to_single_buffer(
sparsity_mask: None or array of bool
If not None shape must be must be (len(unit_ids), len(channel_ids))
copy: bool
If True (default), the output shared memory object is copied to a numpy standard array.
If copy=False then arrays_info is also return. Please keep in mind that arrays_info
need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug.
Also when copy=False the SharedMemory will need to be unlink manually
If True (default), the output shared memory object is copied to a numpy standard array and no reference
to the internal shared memory object is kept.
If copy=False then the shared memory object is also returned. Please keep in mind that the shared memory object
need to be referenced as long as all_waveforms will be used otherwise it might produce segmentation
faults which are hard to debug.
Also when copy=False the SharedMemory will need to be unlink manually if proper cleanup of resources is desired.
{}
Returns
Expand Down Expand Up @@ -481,7 +481,8 @@ def extract_waveforms_to_single_buffer(

if mode == "memmap":
all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape)
wf_array_info = str(file_path)
# wf_array_info = str(file_path)
wf_array_info = dict(filename=str(file_path))
elif mode == "shared_memory":
if num_spikes == 0 or num_chans == 0:
all_waveforms = np.zeros(shape, dtype=dtype)
Expand All @@ -490,7 +491,8 @@ def extract_waveforms_to_single_buffer(
else:
all_waveforms, shm = make_shared_array(shape, dtype)
shm_name = shm.name
wf_array_info = (shm, shm_name, dtype.str, shape)
# wf_array_info = (shm, shm_name, dtype.str, shape)
wf_array_info = dict(shm=shm, shm_name=shm_name, dtype=dtype.str, shape=shape)
else:
raise ValueError("allocate_waveforms_buffers bad mode")

Expand All @@ -503,7 +505,6 @@ def extract_waveforms_to_single_buffer(

init_args = (
recording,
unit_ids,
spikes,
wf_array_info,
nbefore,
Expand Down Expand Up @@ -532,7 +533,7 @@ def extract_waveforms_to_single_buffer(


def _init_worker_distribute_single_buffer(
recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask
recording, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask
):
worker_ctx = {}
worker_ctx["recording"] = recording
Expand All @@ -545,13 +546,13 @@ def _init_worker_distribute_single_buffer(
worker_ctx["mode"] = mode

if mode == "memmap":
filename = wf_array_info
filename = wf_array_info["filename"]
all_waveforms = np.load(str(filename), mmap_mode="r+")
worker_ctx["all_waveforms"] = all_waveforms
elif mode == "shared_memory":
from multiprocessing.shared_memory import SharedMemory

shm, shm_name, dtype, shape = wf_array_info
shm_name, dtype, shape = wf_array_info["shm_name"], wf_array_info["dtype"], wf_array_info["shape"]
shm = SharedMemory(shm_name)
all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
worker_ctx["shm"] = shm
Expand Down Expand Up @@ -587,16 +588,10 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work

# take only spikes in range [start_frame, end_frame]
# this is a slice so no copy!!
i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame)
i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame)
if i0 != i1:
# protect from spikes on border : spike_time<0 or spike_time>seg_size
# useful only when max_spikes_per_unit is not None
# waveform will not be extracted and a zeros will be left in the memmap file
while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1):
i0 = i0 + 1
while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1):
i1 = i1 - 1
# the border of segment are protected by nbefore on left an nafter on the right
i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore))
i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter))


# slice in absolut in spikes vector
l0 = i0 + s0
Expand All @@ -611,17 +606,17 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work
start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled
)

for spike_ind in range(l0, l1):
sample_index = spikes[spike_ind]["sample_index"]
unit_index = spikes[spike_ind]["unit_index"]
for spike_index in range(l0, l1):
sample_index = spikes[spike_index]["sample_index"]
unit_index = spikes[spike_index]["unit_index"]
wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :]

if sparsity_mask is None:
all_waveforms[spike_ind, :, :] = wf
all_waveforms[spike_index, :, :] = wf
else:
mask = sparsity_mask[unit_index, :]
wf = wf[:, mask]
all_waveforms[spike_ind, :, : wf.shape[1]] = wf
all_waveforms[spike_index, :, : wf.shape[1]] = wf

if worker_ctx["mode"] == "memmap":
all_waveforms.flush()
Expand All @@ -642,7 +637,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None
sparsity_mask : None or numpy array
Optionally the boolean sparsity mask
folder : None or str or Path
If a folde ri sgiven all
If a folder is given all waveforms by units are copied in a npy file using f"waveforms_{unit_id}.npy" naming.
Returns
-------
Expand Down

0 comments on commit e5a523c

Please sign in to comment.