
Commit

Merge pull request #417 from SpikeInterface/noise_overlap
Noise overlap metric
alejoe91 authored Dec 9, 2020
2 parents d65940d + b4dfd2b commit 04ebb6b
Showing 7 changed files with 409 additions and 36 deletions.
1 change: 1 addition & 0 deletions spiketoolkit/curation/__init__.py
@@ -10,6 +10,7 @@
from .threshold_metrics import threshold_nn_metrics
from .threshold_metrics import threshold_drift_metrics
from .threshold_metrics import threshold_amplitude_cutoffs
from .threshold_metrics import threshold_noise_overlaps
from ..validation import get_validation_params as get_curation_params
from ..validation.quality_metric_classes.utils.curationsortingextractor import CurationSortingExtractor

96 changes: 96 additions & 0 deletions spiketoolkit/curation/threshold_metrics.py
@@ -8,6 +8,7 @@
from spiketoolkit.validation.quality_metric_classes.presence_ratio import PresenceRatio
from spiketoolkit.validation.quality_metric_classes.isi_violation import ISIViolation
from spiketoolkit.validation.quality_metric_classes.snr import SNR
from spiketoolkit.validation.quality_metric_classes.noise_overlap import NoiseOverlap
from spiketoolkit.validation.quality_metric_classes.isolation_distance import IsolationDistance
from spiketoolkit.validation.quality_metric_classes.nearest_neighbor import NearestNeighbor
from spiketoolkit.validation.quality_metric_classes.drift_metric import DriftMetric
@@ -374,6 +375,101 @@ def threshold_snrs(
    return threshold_sorting


def threshold_noise_overlaps(
        sorting,
        recording,
        threshold,
        threshold_sign,
        num_features=NoiseOverlap.params['num_features'],
        num_knn=NoiseOverlap.params['num_knn'],
        max_spikes_per_unit_for_noise_overlap=NoiseOverlap.params['max_spikes_per_unit_for_noise_overlap'],
        **kwargs
):
"""
Computes and thresholds the snrs in the sorted dataset with the given sign and value.
Parameters
----------
sorting: SortingExtractor
The sorting result to be evaluated.
recording: RecordingExtractor
The given recording extractor
threshold: int or float
The threshold for the given metric.
threshold_sign: str
If 'less', will threshold any metric less than the given threshold.
If 'less_or_equal', will threshold any metric less than or equal to the given threshold.
If 'greater', will threshold any metric greater than the given threshold.
If 'greater_or_equal', will threshold any metric greater than or equal to the given threshold.
num_features: int
Number of features to use for PCA
num_knn: int
Number of nearest neighbors
max_spikes_per_unit_for_noise_overlap: int
Number of waveforms to use for noise overlaps estimation
**kwargs: keyword arguments
Keyword arguments among the following:
method: str
If 'absolute' (default), amplitudes are absolute amplitudes in uV are returned.
If 'relative', amplitudes are returned as ratios between waveform amplitudes and template amplitudes
peak: str
If maximum channel has to be found among negative peaks ('neg'), positive ('pos') or
both ('both' - default)
frames_before: int
Frames before peak to compute amplitude
frames_after: int
Frames after peak to compute amplitude
apply_filter: bool
If True, recording is bandpass-filtered
freq_min: float
High-pass frequency for optional filter (default 300 Hz)
freq_max: float
Low-pass frequency for optional filter (default 6000 Hz)
grouping_property: str
Property to group channels. E.g. if the recording extractor has the 'group' property and
'grouping_property' is 'group', then waveforms are computed group-wise.
ms_before: float
Time period in ms to cut waveforms before the spike events
ms_after: float
Time period in ms to cut waveforms after the spike events
dtype: dtype
The numpy dtype of the waveforms
compute_property_from_recording: bool
If True and 'grouping_property' is given, the property of each unit is assigned as the corresponding
property of the recording extractor channel on which the average waveform is the largest
max_channels_per_waveforms: int or None
Maximum channels per waveforms to return. If None, all channels are returned
n_jobs: int
Number of parallel jobs (default 1)
memmap: bool
If True, waveforms are saved as memmap object (recommended for long recordings with many channels)
save_property_or_features: bool
If true, it will save features in the sorting extractor
recompute_info: bool
If True, waveforms are recomputed
max_spikes_per_unit: int
The maximum number of spikes to extract per unit
seed: int
Random seed for reproducibility
verbose: bool
If True, will be verbose in metric computation
Returns
----------
threshold sorting extractor
"""
    params_dict = update_all_param_dicts_with_kwargs(kwargs)

    md = MetricData(sorting=sorting, sampling_frequency=recording.get_sampling_frequency(), recording=recording,
                    apply_filter=params_dict["apply_filter"], freq_min=params_dict["freq_min"],
                    duration_in_frames=None, freq_max=params_dict["freq_max"], unit_ids=None,
                    verbose=params_dict['verbose'])

    noise_overlap = NoiseOverlap(metric_data=md)
    threshold_sorting = noise_overlap.threshold_metric(threshold, threshold_sign,
                                                       max_spikes_per_unit_for_noise_overlap,
                                                       num_features, num_knn, **kwargs)
    return threshold_sorting


def threshold_silhouette_scores(
        sorting,
        recording,
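A minimal usage sketch of the new curation function (illustrative only; it assumes the toy dataset from spikeextractors, mirroring the tests below):

import spikeextractors as se
from spiketoolkit.curation import threshold_noise_overlaps

rec, sort = se.example_datasets.toy_example(duration=10, num_channels=4, K=10, seed=0)

# Remove units whose noise overlap exceeds 0.3 (high overlap suggests poor isolation);
# the surviving units all have noise_overlap <= 0.3
sort_curated = threshold_noise_overlaps(sort, rec, threshold=0.3, threshold_sign='greater',
                                        apply_filter=False, seed=0)
print(sort_curated.get_unit_ids())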
47 changes: 34 additions & 13 deletions spiketoolkit/tests/test_curation.py
@@ -14,6 +14,7 @@
    threshold_l_ratios,
    threshold_amplitude_cutoffs,
    threshold_isolation_distances,
    threshold_noise_overlaps,
    threshold_nn_metrics,
    threshold_drift_metrics,
    get_curation_params
@@ -29,6 +30,7 @@
    compute_drift_metrics,
    compute_silhouette_scores,
    compute_isolation_distances,
    compute_noise_overlaps,
    compute_l_ratios,
    compute_d_primes,
    compute_nn_metrics,
@@ -135,6 +137,25 @@ def test_thresh_snrs():
    shutil.rmtree('test')


def test_thresh_noise_overlaps():
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10,
                                                seed=0)

    noise_thresh = 0.3

    noise_overlaps = compute_noise_overlaps(sort, rec, apply_filter=False, seed=0)
    sort_noise = threshold_noise_overlaps(sort, rec, noise_thresh, 'less', apply_filter=False, seed=0)

    original_ids = sort.get_unit_ids()
    new_noise = []
    for unit in sort_noise.get_unit_ids():
        new_noise.append(noise_overlaps[original_ids.index(unit)])
    new_noise = np.array(new_noise)
    assert np.all(new_noise >= noise_thresh)
    check_dumping(sort_noise)
    shutil.rmtree('test')


# PCA-based
def test_thresh_isolation_distances():
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10,
@@ -241,16 +262,16 @@ def test_curation_params():


if __name__ == "__main__":
    test_thresh_num_spikes()
    test_thresh_presence_ratios()
    test_thresh_frs()
    test_thresh_isi_violations()

    test_thresh_snrs()
    test_thresh_amplitude_cutoffs()

    test_thresh_silhouettes()
    test_thresh_isolation_distances()
    test_thresh_l_ratios()
    test_thresh_threshold_drift_metrics()
    test_thresh_nn_metrics()
    # test_thresh_num_spikes()
    # test_thresh_presence_ratios()
    # test_thresh_frs()
    # test_thresh_isi_violations()
    #
    # test_thresh_snrs()
    # test_thresh_amplitude_cutoffs()
    test_thresh_noise_overlaps()
    # test_thresh_silhouettes()
    # test_thresh_isolation_distances()
    # test_thresh_l_ratios()
    # test_thresh_threshold_drift_metrics()
    # test_thresh_nn_metrics()
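The assertion in test_thresh_noise_overlaps follows directly from the threshold semantics: threshold_sign='less' removes units whose metric is below the threshold, so every surviving unit satisfies metric >= threshold. A toy check of that invariant (hypothetical values, NumPy only):

import numpy as np

noise_overlaps = np.array([0.1, 0.25, 0.3, 0.45])  # made-up per-unit metrics
noise_thresh = 0.3
kept = noise_overlaps >= noise_thresh  # 'less' discards metric < threshold
assert np.all(noise_overlaps[kept] >= noise_thresh)  # mirrors the test's assertion
print(noise_overlaps[kept])  # [0.3  0.45]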
6 changes: 4 additions & 2 deletions spiketoolkit/tests/test_validation.py
@@ -2,8 +2,8 @@
import numpy as np
from spiketoolkit.validation import compute_isolation_distances, compute_isi_violations, compute_snrs, \
    compute_amplitude_cutoffs, compute_d_primes, compute_drift_metrics, compute_firing_rates, compute_l_ratios, \
    compute_quality_metrics, compute_nn_metrics, compute_num_spikes, compute_presence_ratios, compute_silhouette_scores, \
    get_validation_params
    compute_quality_metrics, compute_nn_metrics, compute_num_spikes, compute_presence_ratios, \
    compute_silhouette_scores, compute_noise_overlaps, get_validation_params


def test_functions():
@@ -19,6 +19,7 @@ def test_functions():
    iso = compute_isolation_distances(sort, rec, seed=0)
    l_ratio = compute_l_ratios(sort, rec, seed=0)
    dprime = compute_d_primes(sort, rec, seed=0)
    noise_overlaps = compute_noise_overlaps(sort, rec, seed=0)
    nn_hit, nn_miss = compute_nn_metrics(sort, rec, seed=0)
    snr = compute_snrs(sort, rec, seed=0)
    metrics = compute_quality_metrics(sort, rec, return_dict=True, seed=0)
@@ -35,6 +36,7 @@
    assert np.allclose(metrics['snr'], snr)
    assert np.allclose(metrics['max_drift'], max_drift)
    assert np.allclose(metrics['cumulative_drift'], cum_drift)
    assert np.allclose(metrics['noise_overlap'], noise_overlaps)
    assert np.allclose(metrics['nn_hit_rate'], nn_hit)
    assert np.allclose(metrics['nn_miss_rate'], nn_miss)

130 changes: 130 additions & 0 deletions spiketoolkit/validation/quality_metric_classes/noise_overlap.py
@@ -0,0 +1,130 @@
import numpy as np
from copy import copy
from .utils.thresholdcurator import ThresholdCurator
from .quality_metric import QualityMetric
import spiketoolkit as st
from spikemetrics.utils import printProgressBar
from collections import OrderedDict
from sklearn.neighbors import NearestNeighbors
from .parameter_dictionaries import update_all_param_dicts_with_kwargs


class NoiseOverlap(QualityMetric):
    installed = True  # check at class level if installed or not
    installation_mesg = ""  # err
    params = OrderedDict([('max_spikes_per_unit_for_noise_overlap', 1000),
                          ('num_features', 10),
                          ('num_knn', 6)])
    curator_name = "ThresholdNoiseOverlaps"

    def __init__(self, metric_data):
        QualityMetric.__init__(self, metric_data, metric_name="noise_overlap")

        if not metric_data.has_recording():
            raise ValueError("MetricData object must have a recording")

    def compute_metric(self, max_spikes_per_unit_for_noise_overlap, num_features, num_knn, **kwargs):
        params_dict = update_all_param_dicts_with_kwargs(kwargs)
        save_property_or_features = params_dict['save_property_or_features']
        seed = params_dict['seed']

        waveforms = st.postprocessing.get_unit_waveforms(
            self._metric_data._recording,
            self._metric_data._sorting,
            unit_ids=self._metric_data._unit_ids,
            max_spikes_per_unit=max_spikes_per_unit_for_noise_overlap,
            **kwargs
        )

        if seed is not None:
            np.random.seed(seed)

        noise_overlaps = []
        for i_u, unit in enumerate(self._metric_data._unit_ids):
            if self._metric_data.verbose:
                printProgressBar(i_u + 1, len(self._metric_data._unit_ids))
            wfs = waveforms[i_u]
            times = self._metric_data._sorting.get_unit_spike_train(unit_id=unit)

            if len(wfs) > max_spikes_per_unit_for_noise_overlap:
                # subsample waveform indices (not spike times) without replacement
                selected_idxs = np.random.choice(len(wfs), size=max_spikes_per_unit_for_noise_overlap,
                                                 replace=False)
                wfs = wfs[selected_idxs]

            # get clip_size from waveforms shape
            clip_size = wfs.shape[-1]

            # draw the same number of control "noise" clips at random times in the recording
            num_clips = len(wfs)
            min_time = np.min(times)
            max_time = np.max(times)
            times_control = np.random.choice(np.arange(min_time, max_time), size=num_clips)
            clips = copy(wfs)
            clips_control = np.stack(self._metric_data._recording.get_snippets(snippet_len=clip_size,
                                                                               reference_frames=times_control))

            # locate the template peak (channel, sample) and weight each control clip by its
            # signed agreement with that peak, so the noise template emphasizes spike-like noise
            template = np.median(wfs, axis=0)
            max_ind = np.unravel_index(np.argmax(np.abs(template)), template.shape)
            chmax = max_ind[0]
            tmax = max_ind[1]
            max_val = template[chmax, tmax]
            weighted_clips_control = np.zeros(clips_control.shape)
            weights = np.zeros(num_clips)
            for j in range(num_clips):
                clip0 = clips_control[j, :, :]
                val0 = clip0[chmax, tmax]
                weight0 = val0 * max_val
                weights[j] = weight0
                weighted_clips_control[j, :, :] = clip0 * weight0

            # rescale the noise template to the same total absolute amplitude as the template
            noise_template = np.sum(weighted_clips_control, axis=0)
            noise_template = noise_template / np.sum(np.abs(noise_template)) * np.sum(np.abs(template))

            # project the noise-template component out of both spike and control clips
            for j in range(num_clips):
                clips[j, :, :] = _subtract_clip_component(clips[j, :, :], noise_template)
                clips_control[j, :, :] = _subtract_clip_component(clips_control[j, :, :], noise_template)

            # pool spike and control clips and reduce them to num_features PCA features
            all_clips = np.concatenate([clips, clips_control], axis=0)
            num_channels_wfs = all_clips.shape[1]
            num_samples_wfs = all_clips.shape[2]
            all_features = _compute_pca_features(all_clips.reshape((num_clips * 2,
                                                                    num_channels_wfs * num_samples_wfs)),
                                                 num_features)

            distances, indices = NearestNeighbors(n_neighbors=num_knn + 1, algorithm='auto').fit(
                all_features.T).kneighbors()

            # fraction of k-nearest neighbors that share a clip's group (spike vs control);
            # well-isolated units mix little with noise, so pct_match is high
            group_id = np.zeros((num_clips * 2))
            group_id[0:num_clips] = 1
            group_id[num_clips:] = 2
            num_match = 0
            total = 0
            for j in range(num_clips * 2):
                for k in range(1, num_knn + 1):
                    ind = indices[j][k]
                    if group_id[j] == group_id[ind]:
                        num_match = num_match + 1
                    total = total + 1
            pct_match = num_match / total
            noise_overlap = 1 - pct_match
            noise_overlaps.append(noise_overlap)
        noise_overlaps = np.asarray(noise_overlaps)
        if save_property_or_features:
            self.save_property_or_features(self._metric_data._sorting, noise_overlaps, self._metric_name)
        return noise_overlaps

    def threshold_metric(self, threshold, threshold_sign, max_spikes_per_unit_for_noise_overlap,
                         num_features, num_knn, **kwargs):
        noise_overlaps = self.compute_metric(max_spikes_per_unit_for_noise_overlap, num_features, num_knn, **kwargs)
        threshold_curator = ThresholdCurator(sorting=self._metric_data._sorting, metric=noise_overlaps)
        threshold_curator.threshold_sorting(threshold=threshold, threshold_sign=threshold_sign)
        return threshold_curator


def _compute_pca_features(X, num_components):
    # rows of X are observations; the left singular vectors give the PCA scores
    u, s, vt = np.linalg.svd(X)
    return u[:, :num_components].T


def _subtract_clip_component(clip1, component):
    # remove the mean, then project out the (mean-subtracted) component direction
    V1 = clip1.flatten()
    V2 = component.flatten()
    V1 = V1 - np.mean(V1)
    V2 = V2 - np.mean(V2)
    V1 = V1 - V2 * np.dot(V1, V2) / np.dot(V2, V2)
    return V1.reshape(clip1.shape)
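For intuition, the core of the metric is the k-nearest-neighbor label-mixing step in compute_metric above: spike clips and weighted noise clips are pooled in PCA space, and the overlap is one minus the fraction of neighbors drawn from a clip's own group. A self-contained sketch of just that step on synthetic features (all names illustrative, not part of this PR):

import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
num_clips, num_features, num_knn = 100, 10, 6

# two feature clouds standing in for spike clips (group 1) and noise clips (group 2)
spike_feats = rng.normal(loc=3.0, size=(num_clips, num_features))
noise_feats = rng.normal(loc=0.0, size=(num_clips, num_features))
all_features = np.concatenate([spike_feats, noise_feats], axis=0)
group_id = np.concatenate([np.ones(num_clips), 2 * np.ones(num_clips)])

# kneighbors() on the fitted data excludes each query point itself
_, indices = NearestNeighbors(n_neighbors=num_knn).fit(all_features).kneighbors()
pct_match = np.mean(group_id[indices] == group_id[:, None])
print(1 - pct_match)  # near 0 for well-separated clouds, near 0.5 for identical ones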
