Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Commit

Permalink
Implemented noise_overlap quality metric and curator
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Nov 25, 2020
1 parent e02c2fa commit b4dfd2b
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 22 deletions.
1 change: 1 addition & 0 deletions spiketoolkit/curation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .threshold_metrics import threshold_nn_metrics
from .threshold_metrics import threshold_drift_metrics
from .threshold_metrics import threshold_amplitude_cutoffs
from .threshold_metrics import threshold_noise_overlaps
from ..validation import get_validation_params as get_curation_params
from ..validation.quality_metric_classes.utils.curationsortingextractor import CurationSortingExtractor

96 changes: 96 additions & 0 deletions spiketoolkit/curation/threshold_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from spiketoolkit.validation.quality_metric_classes.presence_ratio import PresenceRatio
from spiketoolkit.validation.quality_metric_classes.isi_violation import ISIViolation
from spiketoolkit.validation.quality_metric_classes.snr import SNR
from spiketoolkit.validation.quality_metric_classes.noise_overlap import NoiseOverlap
from spiketoolkit.validation.quality_metric_classes.isolation_distance import IsolationDistance
from spiketoolkit.validation.quality_metric_classes.nearest_neighbor import NearestNeighbor
from spiketoolkit.validation.quality_metric_classes.drift_metric import DriftMetric
Expand Down Expand Up @@ -374,6 +375,101 @@ def threshold_snrs(
return threshold_sorting


def threshold_noise_overlaps(
sorting,
recording,
threshold,
threshold_sign,
num_features=NoiseOverlap.params['num_features'],
num_knn=NoiseOverlap.params['num_knn'],
max_spikes_per_unit_for_noise_overlap=NoiseOverlap.params['max_spikes_per_unit_for_noise_overlap'],
**kwargs
):
"""
Computes and thresholds the snrs in the sorted dataset with the given sign and value.
Parameters
----------
sorting: SortingExtractor
The sorting result to be evaluated.
recording: RecordingExtractor
The given recording extractor
threshold: int or float
The threshold for the given metric.
threshold_sign: str
If 'less', will threshold any metric less than the given threshold.
If 'less_or_equal', will threshold any metric less than or equal to the given threshold.
If 'greater', will threshold any metric greater than the given threshold.
If 'greater_or_equal', will threshold any metric greater than or equal to the given threshold.
num_features: int
Number of features to use for PCA
num_knn: int
Number of nearest neighbors
max_spikes_per_unit_for_noise_overlap: int
Number of waveforms to use for noise overlaps estimation
**kwargs: keyword arguments
Keyword arguments among the following:
method: str
If 'absolute' (default), amplitudes are absolute amplitudes in uV are returned.
If 'relative', amplitudes are returned as ratios between waveform amplitudes and template amplitudes
peak: str
If maximum channel has to be found among negative peaks ('neg'), positive ('pos') or
both ('both' - default)
frames_before: int
Frames before peak to compute amplitude
frames_after: int
Frames after peak to compute amplitude
apply_filter: bool
If True, recording is bandpass-filtered
freq_min: float
High-pass frequency for optional filter (default 300 Hz)
freq_max: float
Low-pass frequency for optional filter (default 6000 Hz)
grouping_property: str
Property to group channels. E.g. if the recording extractor has the 'group' property and
'grouping_property' is 'group', then waveforms are computed group-wise.
ms_before: float
Time period in ms to cut waveforms before the spike events
ms_after: float
Time period in ms to cut waveforms after the spike events
dtype: dtype
The numpy dtype of the waveforms
compute_property_from_recording: bool
If True and 'grouping_property' is given, the property of each unit is assigned as the corresponding
property of the recording extractor channel on which the average waveform is the largest
max_channels_per_waveforms: int or None
Maximum channels per waveforms to return. If None, all channels are returned
n_jobs: int
Number of parallel jobs (default 1)
memmap: bool
If True, waveforms are saved as memmap object (recommended for long recordings with many channels)
save_property_or_features: bool
If true, it will save features in the sorting extractor
recompute_info: bool
If True, waveforms are recomputed
max_spikes_per_unit: int
The maximum number of spikes to extract per unit
seed: int
Random seed for reproducibility
verbose: bool
If True, will be verbose in metric computation
Returns
----------
threshold sorting extractor
"""
params_dict = update_all_param_dicts_with_kwargs(kwargs)

md = MetricData(sorting=sorting, sampling_frequency=recording.get_sampling_frequency(), recording=recording,
apply_filter=params_dict["apply_filter"], freq_min=params_dict["freq_min"],
duration_in_frames=None, freq_max=params_dict["freq_max"], unit_ids=None, verbose=params_dict['verbose'])

noise_overlap = NoiseOverlap(metric_data=md)
threshold_sorting = noise_overlap.threshold_metric(threshold, threshold_sign, max_spikes_per_unit_for_noise_overlap,
num_features, num_knn, **kwargs)
return threshold_sorting


def threshold_silhouette_scores(
sorting,
recording,
Expand Down
47 changes: 34 additions & 13 deletions spiketoolkit/tests/test_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
threshold_l_ratios,
threshold_amplitude_cutoffs,
threshold_isolation_distances,
threshold_noise_overlaps,
threshold_nn_metrics,
threshold_drift_metrics,
get_curation_params
Expand All @@ -29,6 +30,7 @@
compute_drift_metrics,
compute_silhouette_scores,
compute_isolation_distances,
compute_noise_overlaps,
compute_l_ratios,
compute_d_primes,
compute_nn_metrics,
Expand Down Expand Up @@ -135,6 +137,25 @@ def test_thresh_snrs():
shutil.rmtree('test')


def test_thresh_noise_overlaps():
rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10,
seed=0)

noise_thresh = 0.3

noise_overlaps = compute_noise_overlaps(sort, rec, apply_filter=False, seed=0)
sort_noise = threshold_noise_overlaps(sort, rec, noise_thresh, 'less', apply_filter=False, seed=0)

original_ids = sort.get_unit_ids()
new_noise = []
for unit in sort_noise.get_unit_ids():
new_noise.append(noise_overlaps[original_ids.index(unit)])
new_noise = np.array(new_noise)
assert np.all(new_noise >= noise_thresh)
check_dumping(sort_noise)
shutil.rmtree('test')


# PCA-based
def test_thresh_isolation_distances():
rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10,
Expand Down Expand Up @@ -241,16 +262,16 @@ def test_curation_params():


if __name__ == "__main__":
test_thresh_num_spikes()
test_thresh_presence_ratios()
test_thresh_frs()
test_thresh_isi_violations()

test_thresh_snrs()
test_thresh_amplitude_cutoffs()

test_thresh_silhouettes()
test_thresh_isolation_distances()
test_thresh_l_ratios()
test_thresh_threshold_drift_metrics()
test_thresh_nn_metrics()
# test_thresh_num_spikes()
# test_thresh_presence_ratios()
# test_thresh_frs()
# test_thresh_isi_violations()
#
# test_thresh_snrs()
# test_thresh_amplitude_cutoffs()
test_thresh_noise_overlaps()
# test_thresh_silhouettes()
# test_thresh_isolation_distances()
# test_thresh_l_ratios()
# test_thresh_threshold_drift_metrics()
# test_thresh_nn_metrics()
6 changes: 4 additions & 2 deletions spiketoolkit/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import numpy as np
from spiketoolkit.validation import compute_isolation_distances, compute_isi_violations, compute_snrs, \
compute_amplitude_cutoffs, compute_d_primes, compute_drift_metrics, compute_firing_rates, compute_l_ratios, \
compute_quality_metrics, compute_nn_metrics, compute_num_spikes, compute_presence_ratios, compute_silhouette_scores, \
get_validation_params
compute_quality_metrics, compute_nn_metrics, compute_num_spikes, compute_presence_ratios, \
compute_silhouette_scores, compute_noise_overlaps, get_validation_params


def test_functions():
Expand All @@ -19,6 +19,7 @@ def test_functions():
iso = compute_isolation_distances(sort, rec, seed=0)
l_ratio = compute_l_ratios(sort, rec, seed=0)
dprime = compute_d_primes(sort, rec, seed=0)
noise_overlaps = compute_noise_overlaps(sort, rec, seed=0)
nn_hit, nn_miss = compute_nn_metrics(sort, rec, seed=0)
snr = compute_snrs(sort, rec, seed=0)
metrics = compute_quality_metrics(sort, rec, return_dict=True, seed=0)
Expand All @@ -35,6 +36,7 @@ def test_functions():
assert np.allclose(metrics['snr'], snr)
assert np.allclose(metrics['max_drift'], max_drift)
assert np.allclose(metrics['cumulative_drift'], cum_drift)
assert np.allclose(metrics['noise_overlap'], noise_overlaps)
assert np.allclose(metrics['nn_hit_rate'], nn_hit)
assert np.allclose(metrics['nn_miss_rate'], nn_miss)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def compute_metric(self, max_spikes_per_unit_for_noise_overlap, num_features, nu
save_property_or_features = params_dict['save_property_or_features']
seed = params_dict['seed']

waveforms = st.postprocessing.get_unit_templates(
waveforms = st.postprocessing.get_unit_waveforms(
self._metric_data._recording,
self._metric_data._sorting,
unit_ids=self._metric_data._unit_ids,
Expand All @@ -39,7 +39,7 @@ def compute_metric(self, max_spikes_per_unit_for_noise_overlap, num_features, nu
np.random.seed(seed)

noise_overlaps = []
for i_u, unit in self._metric_data._unit_ids:
for i_u, unit in enumerate(self._metric_data._unit_ids):
if self._metric_data.verbose:
printProgressBar(i_u + 1, len(self._metric_data._unit_ids))
wfs = waveforms[i_u]
Expand All @@ -50,7 +50,7 @@ def compute_metric(self, max_spikes_per_unit_for_noise_overlap, num_features, nu
wfs = wfs[selecte_idxs]

# get clip_size from waveforms shape
clip_size = wfs.shape[-1] // 2
clip_size = wfs.shape[-1]

num_clips = len(wfs)
min_time = np.min(times)
Expand Down
29 changes: 26 additions & 3 deletions spiketoolkit/validation/quality_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def compute_snrs(
return snrs


def compute_noise_overlap(
def compute_noise_overlaps(
sorting,
recording,
num_features=NoiseOverlap.params['num_features'],
Expand All @@ -398,6 +398,11 @@ def compute_noise_overlap(
):
"""
Computes and returns the noise overlaps in the sorted dataset.
Noise overlap estimates the fraction of ‘‘noise events’’ in a cluster, i.e., above-threshold events not associated
with true firings of this or any of the other clustered units. A large noise overlap implies a high false-positive
rate.
Implementation from ml_ms4alg. For more information see https://doi.org/10.1016/j.neuron.2017.08.030
Parameters
----------
Expand Down Expand Up @@ -1053,6 +1058,9 @@ def compute_quality_metrics(
max_spikes_per_unit_for_snr=SNR.params['max_spikes_per_unit_for_snr'],
template_mode=SNR.params['template_mode'],
max_channel_peak=SNR.params['max_channel_peak'],
max_spikes_per_unit_for_noise_overlap=NoiseOverlap.params['max_spikes_per_unit_for_noise_overlap'],
noise_overlap_num_features=NoiseOverlap.params['num_features'],
noise_overlap_num_knn=NoiseOverlap.params['num_knn'],
drift_metrics_interval_s=DriftMetric.params['drift_metrics_interval_s'],
drift_metrics_min_spikes_per_interval=DriftMetric.params['drift_metrics_min_spikes_per_interval'],
max_spikes_for_silhouette=SilhouetteScore.params['max_spikes_for_silhouette'],
Expand Down Expand Up @@ -1090,11 +1098,17 @@ def compute_quality_metrics(
snr_noise_duration: float
Number of seconds to compute noise level from (default 10.0)
max_spikes_per_unit_for_snr: int
Maximum number of spikes to compute templates from (default 1000)
Maximum number of spikes to compute templates for SNR from (default 1000)
template_mode: str
Use 'mean' or 'median' to compute templates
max_channel_peak: str
If maximum channel has to be found among negative peaks ('neg'), positive ('pos') or both ('both' - default)
max_spikes_per_unit_for_noise_overlap: int
Maximum number of spikes to compute templates for noise overlap from (default 1000)
noise_overlap_num_features: int
Number of features to use for PCA for noise overlap
noise_overlap_num_knn: int
Number of nearest neighbors for noise overlap
drift_metrics_interval_s: float
Time period for evaluating drift.
drift_metrics_min_spikes_per_interval: int
Expand Down Expand Up @@ -1187,7 +1201,8 @@ def compute_quality_metrics(
if "firing_rate" in metric_names or "presence_ratio" in metric_names or "isi_violation" in metric_names:
if recording is None and duration_in_frames is None:
raise ValueError(
"duration_in_frames and recording cannot both be None when computing firing_rate, presence_ratio, and isi_violation")
"duration_in_frames and recording cannot both be None when computing firing_rate, "
"presence_ratio, and isi_violation")

if "max_drift" in metric_names or "cumulative_drift" in metric_names or "silhouette_score" in metric_names \
or "isolation_distance" in metric_names or "l_ratio" in metric_names or "d_prime" in metric_names \
Expand Down Expand Up @@ -1258,6 +1273,14 @@ def compute_quality_metrics(
**kwargs)
metrics_dict['isolation_distance'] = isolation_distances

if "noise_overlap" in metric_names:
noise_overlap = NoiseOverlap(metric_data=md)
noise_overlaps = noise_overlap.compute_metric(max_spikes_per_unit_for_noise_overlap,
noise_overlap_num_features,
noise_overlap_num_knn,
**kwargs)
metrics_dict['noise_overlap'] = noise_overlaps

if "l_ratio" in metric_names:
l_ratio = LRatio(metric_data=md)
l_ratios = l_ratio.compute_metric(num_channels_to_compare, max_spikes_per_cluster, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion spiketoolkit/validation/validation_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
compute_l_ratios,
compute_d_primes,
compute_nn_metrics,
compute_noise_overlap,
compute_noise_overlaps,
compute_quality_metrics,
get_quality_metrics_list
)
Expand Down

0 comments on commit b4dfd2b

Please sign in to comment.