Waveform tools speedup #1799

Merged: 18 commits, Sep 13, 2023. Changes from 9 commits.
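For orientation before the diff: the reworked test checks that two extraction paths produce identical waveforms, extracting one buffer per unit versus extracting a single buffer for all spikes and splitting it per unit afterwards. The sketch below mirrors the calls that appear in the diff; the toy data sizes, dtype, snippet cut lengths, and job kwargs are illustrative assumptions, not values from this PR (note also that the test itself skips shared_memory mode on Windows).

    import numpy as np
    from spikeinterface.core import generate_recording, generate_sorting
    from spikeinterface.core.waveform_tools import (
        extract_waveforms_to_buffers,
        extract_waveforms_to_single_buffer,
        split_waveforms_by_units,
    )

    # toy data (sizes are illustrative)
    recording = generate_recording(num_channels=4, durations=[10.0])
    sorting = generate_sorting(num_units=3, durations=[10.0])
    spikes = sorting.to_spike_vector()
    nbefore, nafter = 20, 30

    common = dict(
        mode="shared_memory",  # "memmap" would also need a folder argument
        return_scaled=False,
        dtype="float32",
        sparsity_mask=None,  # dense: keep all channels for every unit
        copy=True,
        n_jobs=1,
        chunk_size=3000,
    )

    # path 1: one buffer per unit
    wfs_per_unit = extract_waveforms_to_buffers(recording, spikes, sorting.unit_ids, nbefore, nafter, **common)

    # path 2: one buffer for all spikes, split per unit afterwards
    all_waveforms = extract_waveforms_to_single_buffer(recording, spikes, sorting.unit_ids, nbefore, nafter, **common)
    split = split_waveforms_by_units(sorting.unit_ids, spikes, all_waveforms, sparsity_mask=None)

    # the two paths must agree unit by unit, which is what the test asserts
    for unit_id in sorting.unit_ids:
        assert np.array_equal(wfs_per_unit[unit_id], split[unit_id])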
src/spikeinterface/core/tests/test_waveform_tools.py (164 changes: 79 additions & 85 deletions)

@@ -8,7 +8,9 @@
from spikeinterface.core import generate_recording, generate_sorting
from spikeinterface.core.waveform_tools import (
    extract_waveforms_to_buffers,
-)  # allocate_waveforms_buffers, distribute_waveforms_to_buffers
+    extract_waveforms_to_single_buffer,
+    split_waveforms_by_units,
+)


if hasattr(pytest, "global_test_folder"):
@@ -52,96 +54,88 @@ def test_waveform_tools():
    unit_ids = sorting.unit_ids

    some_job_kwargs = [
        {},
        {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True},
        {"n_jobs": 2, "chunk_size": 3000, "progress_bar": True},
    ]
+    some_modes = [
+        {"mode": "memmap"},
+    ]
+    if platform.system() != "Windows":
+        # shared memory on windows is buggy...
+        some_modes.append(
+            {
+                "mode": "shared_memory",
+            }
+        )
+
+    some_sparsity = [
+        dict(sparsity_mask=None),
+        dict(sparsity_mask=np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")),
+    ]

-    # memmap mode
-    list_wfs = []
+    list_wfs_dense = []
+    list_wfs_sparse = []
    for j, job_kwargs in enumerate(some_job_kwargs):
-        wf_folder = cache_folder / f"test_waveform_tools_{j}"
-        if wf_folder.is_dir():
-            shutil.rmtree(wf_folder)
-        wf_folder.mkdir(parents=True)
-        # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype)
-        # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, **job_kwargs)
-        wfs_arrays = extract_waveforms_to_buffers(
-            recording,
-            spikes,
-            unit_ids,
-            nbefore,
-            nafter,
-            mode="memmap",
-            return_scaled=False,
-            folder=wf_folder,
-            dtype=dtype,
-            sparsity_mask=None,
-            copy=False,
-            **job_kwargs,
-        )
-        for unit_ind, unit_id in enumerate(unit_ids):
-            wf = wfs_arrays[unit_id]
-            assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind)
-        list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids})
-    _check_all_wf_equal(list_wfs)
-
-    # memory
-    if platform.system() != "Windows":
-        # shared memory on windows is buggy...
-        list_wfs = []
-        for job_kwargs in some_job_kwargs:
-            # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='shared_memory', folder=None, dtype=dtype)
-            # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, mode='shared_memory', **job_kwargs)
-            wfs_arrays = extract_waveforms_to_buffers(
-                recording,
-                spikes,
-                unit_ids,
-                nbefore,
-                nafter,
-                mode="shared_memory",
-                return_scaled=False,
-                folder=None,
-                dtype=dtype,
-                sparsity_mask=None,
-                copy=True,
-                **job_kwargs,
-            )
-            for unit_ind, unit_id in enumerate(unit_ids):
-                wf = wfs_arrays[unit_id]
-                assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind)
-            list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids})
-            # to avoid warning we need to first destroy arrays then sharedmemm object
-            # del wfs_arrays
-            # del wfs_arrays_info
-        _check_all_wf_equal(list_wfs)
-
-    # with sparsity
-    wf_folder = cache_folder / "test_waveform_tools_sparse"
-    if wf_folder.is_dir():
-        shutil.rmtree(wf_folder)
-    wf_folder.mkdir()
-
-    sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")
-    job_kwargs = {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True}
-
-    # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype, sparsity_mask=sparsity_mask)
-    # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, sparsity_mask=sparsity_mask, **job_kwargs)
-
-    wfs_arrays = extract_waveforms_to_buffers(
-        recording,
-        spikes,
-        unit_ids,
-        nbefore,
-        nafter,
-        mode="memmap",
-        return_scaled=False,
-        folder=wf_folder,
-        dtype=dtype,
-        sparsity_mask=sparsity_mask,
-        copy=False,
-        **job_kwargs,
-    )
+        for k, mode_kwargs in enumerate(some_modes):
+            for l, sparsity_kwargs in enumerate(some_sparsity):
+                # print()
+                # print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None)
+
+                if mode_kwargs["mode"] == "memmap":
+                    wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}"
+                    if wf_folder.is_dir():
+                        shutil.rmtree(wf_folder)
+                    wf_folder.mkdir(parents=True)
+                    mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder)
+                else:
+                    mode_kwargs_ = mode_kwargs
+
+                wfs_arrays = extract_waveforms_to_buffers(
+                    recording,
+                    spikes,
+                    unit_ids,
+                    nbefore,
+                    nafter,
+                    return_scaled=False,
+                    dtype=dtype,
+                    copy=True,
+                    **sparsity_kwargs,
+                    **mode_kwargs_,
+                    **job_kwargs,
+                )
+                for unit_ind, unit_id in enumerate(unit_ids):
+                    wf = wfs_arrays[unit_id]
+                    assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind)
+
+                if sparsity_kwargs["sparsity_mask"] is None:
+                    list_wfs_dense.append(wfs_arrays)
+                else:
+                    list_wfs_sparse.append(wfs_arrays)
+
+                all_waveforms = extract_waveforms_to_single_buffer(
+                    recording,
+                    spikes,
+                    unit_ids,
+                    nbefore,
+                    nafter,
+                    return_scaled=False,
+                    dtype=dtype,
+                    copy=True,
+                    **sparsity_kwargs,
+                    **mode_kwargs_,
+                    **job_kwargs,
+                )
+                wfs_arrays = split_waveforms_by_units(
+                    unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs["sparsity_mask"]
+                )
+                if sparsity_kwargs["sparsity_mask"] is None:
+                    list_wfs_dense.append(wfs_arrays)
+                else:
+                    list_wfs_sparse.append(wfs_arrays)
+
+    _check_all_wf_equal(list_wfs_dense)
+    _check_all_wf_equal(list_wfs_sparse)


if __name__ == "__main__":
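
For reference, _check_all_wf_equal is a helper defined earlier in this test file (collapsed in the diff view above). Below is a minimal sketch of what it needs to verify given how it is called here; this is a plausible reconstruction, not the actual helper. Each list entry is a dict mapping unit_id to a waveforms array, and every entry must match the first.

    import numpy as np

    def _check_all_wf_equal(list_wfs_arrays):
        # plausible sketch, not the real helper: all runs (different job kwargs,
        # modes and buffer layouts) must yield identical per-unit waveforms
        wfs_arrays0 = list_wfs_arrays[0]
        for wfs_arrays in list_wfs_arrays[1:]:
            assert set(wfs_arrays.keys()) == set(wfs_arrays0.keys())
            for unit_id, wf in wfs_arrays.items():
                assert np.array_equal(wf, wfs_arrays0[unit_id])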
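A last note on the sparsity grid used above. The mask semantics are fixed by the diff itself: a boolean array of shape (num_units, num_channels), with sparsity_mask=None meaning dense extraction. A small sketch follows; the shape of the resulting per-unit arrays is an assumption based on how waveforms are laid out elsewhere in spikeinterface.

    import numpy as np

    num_units, num_channels = 5, 4
    # True marks the channels kept for a unit; passing None instead means dense
    sparsity_mask = np.random.randint(0, 2, size=(num_units, num_channels), dtype="bool")
    # for unit u, sparse waveforms then presumably have shape
    # (num_spikes_u, nbefore + nafter, sparsity_mask[u].sum())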