Skip to content

Commit

Permalink
Merge branch 'main' into probeinterface-update
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia authored Oct 30, 2023
2 parents 1b9ab35 + 1a23d1d commit 8ec4931
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 26 deletions.
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ widgets = [
"sortingview>=0.11.15",
]

qualitymetrics = [
"scikit-learn",
"scipy",
"pandas",
"numba",
]

test_core = [
"pytest",
"zarr",
Expand Down
42 changes: 42 additions & 0 deletions src/spikeinterface/core/tests/test_waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,48 @@ def test_extract_waveforms():
)
assert we4.sparsity is not None

# test with sparsity estimation
folder5 = cache_folder / "test_extract_waveforms_compute_sparsity_tmp_folder"
sparsity_temp_folder = cache_folder / "tmp_sparsity"
if folder5.is_dir():
shutil.rmtree(folder5)

we5 = extract_waveforms(
recording,
sorting,
folder5,
max_spikes_per_unit=100,
return_scaled=True,
sparse=True,
sparsity_temp_folder=sparsity_temp_folder,
method="radius",
radius_um=50.0,
n_jobs=2,
chunk_duration="500ms",
)
assert we5.sparsity is not None
# tmp folder is cleaned up
assert not sparsity_temp_folder.is_dir()

# should raise an error if sparsity_temp_folder is not empty
with pytest.raises(AssertionError):
if folder5.is_dir():
shutil.rmtree(folder5)
sparsity_temp_folder.mkdir()
we5 = extract_waveforms(
recording,
sorting,
folder5,
max_spikes_per_unit=100,
return_scaled=True,
sparse=True,
sparsity_temp_folder=sparsity_temp_folder,
method="radius",
radius_um=50.0,
n_jobs=2,
chunk_duration="500ms",
)


def test_recordingless():
durations = [30, 40]
Expand Down
79 changes: 56 additions & 23 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,9 @@ def extract_waveforms(
dtype=None,
sparse=True,
sparsity=None,
sparsity_temp_folder=None,
num_spikes_for_sparsity=100,
unit_batch_size=200,
allow_unfiltered=False,
use_relative_path=False,
seed=None,
Expand Down Expand Up @@ -1531,8 +1533,15 @@ def extract_waveforms(
sparsity you want to apply (by radius, by best channels, ...).
sparsity: ChannelSparsity or None
The sparsity used to compute waveforms. If this is given, `sparse` is ignored. Default None.
sparsity_temp_folder: str or Path or None, default: None
If sparse is True, this is the temporary folder where the dense waveforms are temporarily saved.
If None, dense waveforms are extracted in memory in batches (which can be controlled by the `unit_batch_size`
parameter). With a large number of units (e.g., > 400), it is advisable to use a temporary folder.
num_spikes_for_sparsity: int (default 100)
The number of spikes to use to estimate sparsity (if sparse=True).
unit_batch_size: int, default 200
The number of units to process at once when extracting dense waveforms (if sparse=True and sparsity_temp_folder
is None).
allow_unfiltered: bool
If true, will accept an allow_unfiltered recording.
False by default.
Expand Down Expand Up @@ -1612,6 +1621,8 @@ def extract_waveforms(
ms_before=ms_before,
ms_after=ms_after,
num_spikes_for_sparsity=num_spikes_for_sparsity,
unit_batch_size=unit_batch_size,
temp_folder=sparsity_temp_folder,
allow_unfiltered=allow_unfiltered,
**estimate_kwargs,
**job_kwargs,
Expand Down Expand Up @@ -1675,6 +1686,7 @@ def precompute_sparsity(
unit_batch_size=200,
ms_before=2.0,
ms_after=3.0,
temp_folder=None,
allow_unfiltered=False,
**kwargs,
):
Expand All @@ -1689,25 +1701,25 @@ def precompute_sparsity(
The recording object
sorting: Sorting
The sorting object
num_spikes_for_sparsity: int
How many spikes per unit.
unit_batch_size: int or None
num_spikes_for_sparsity: int, default 100
How many spikes per unit
unit_batch_size: int or None, default 200
How many units are extracted at once to estimate sparsity.
If None then they are extracted all at one (consum many memory)
ms_before: float
If None then they are extracted all at once (but uses a lot of memory)
ms_before: float, default 2.0
Time in ms to cut before spike peak
ms_after: float
ms_after: float, default 3.0
Time in ms to cut after spike peak
allow_unfiltered: bool
temp_folder: str or Path or None, default: None
If provided, dense waveforms are saved to this temporary folder
allow_unfiltered: bool, default: False
If true, will accept an allow_unfiltered recording.
False by default.
kwargs for sparsity strategy:
{}
Job kwargs:
job kwargs:
{}
Returns
Expand All @@ -1724,18 +1736,38 @@ def precompute_sparsity(
if unit_batch_size is None:
unit_batch_size = len(unit_ids)

mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool")

nloop = int(np.ceil((unit_ids.size / unit_batch_size)))
for i in range(nloop):
sl = slice(i * unit_batch_size, (i + 1) * unit_batch_size)
local_ids = unit_ids[sl]
local_sorting = sorting.select_units(local_ids)
local_we = extract_waveforms(
if temp_folder is None:
mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool")
nloop = int(np.ceil((unit_ids.size / unit_batch_size)))
for i in range(nloop):
sl = slice(i * unit_batch_size, (i + 1) * unit_batch_size)
local_ids = unit_ids[sl]
local_sorting = sorting.select_units(local_ids)
local_we = extract_waveforms(
recording,
local_sorting,
folder=None,
mode="memory",
precompute_template=("average",),
ms_before=ms_before,
ms_after=ms_after,
max_spikes_per_unit=num_spikes_for_sparsity,
return_scaled=False,
allow_unfiltered=allow_unfiltered,
sparse=False,
**job_kwargs,
)
local_sparsity = compute_sparsity(local_we, **sparse_kwargs)
mask[sl, :] = local_sparsity.mask
else:
temp_folder = Path(temp_folder)
assert (
not temp_folder.is_dir()
), "Temporary folder for pre-computing sparsity already exists. Provide a non-existing folder"
dense_we = extract_waveforms(
recording,
local_sorting,
folder=None,
mode="memory",
sorting,
folder=temp_folder,
precompute_template=("average",),
ms_before=ms_before,
ms_after=ms_after,
Expand All @@ -1745,8 +1777,9 @@ def precompute_sparsity(
sparse=False,
**job_kwargs,
)
local_sparsity = compute_sparsity(local_we, **sparse_kwargs)
mask[sl, :] = local_sparsity.mask
sparsity = compute_sparsity(dense_we, **sparse_kwargs)
mask = sparsity.mask
shutil.rmtree(temp_folder)

sparsity = ChannelSparsity(mask, unit_ids, channel_ids)
return sparsity
Expand Down
2 changes: 0 additions & 2 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Cluster quality metrics computed from principal components."""

from cmath import nan
from copy import deepcopy

import numpy as np
Expand All @@ -16,7 +15,6 @@
except:
pass

import spikeinterface as si
from ..core import get_random_data_chunks, compute_sparsity, WaveformExtractor
from ..core.job_tools import tqdm_joblib
from ..core.template_tools import get_template_extremum_channel
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/qualitymetrics/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from scipy.stats import norm, multivariate_normal
from scipy.stats import multivariate_normal


def create_ground_truth_pc_distributions(center_locations, total_points):
Expand Down

0 comments on commit 8ec4931

Please sign in to comment.