diff --git a/doc/modules/qualitymetrics/references.rst b/doc/modules/qualitymetrics/references.rst index 8dd8a21548..4f10c7b2b7 100644 --- a/doc/modules/qualitymetrics/references.rst +++ b/doc/modules/qualitymetrics/references.rst @@ -11,6 +11,8 @@ References .. [Hruschka] Hruschka, E.R., de Castro, L.N., Campello R.J.G.B. "Evolutionary algorithms for clustering gene-expression data." Fourth IEEE International Conference on Data Mining (ICDM'04) 2004, pp 403-406. +.. [Gruen] Sonja Grün, Moshe Abeles, and Markus Diesmann. Impact of higher-order correlations on coincidence distributions of massively parallel data. In International School on Neural Networks, Initiated by IIASS and EMFCSC, volume 5286, 96–114. Springer, 2007. + .. [IBL] International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. .. [Jackson] Jadin Jackson, Neil Schmitzer-Torbert, K.D. Harris, and A.D. Redish. Quantitative assessment of extracellular multichannel recording quality using measures of cluster separation. Soc Neurosci Abstr, 518, 01 2005. diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst new file mode 100644 index 0000000000..b41e194466 --- /dev/null +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -0,0 +1,49 @@ +Synchrony Metrics (:code:`synchrony`) +===================================== + +Calculation +----------- +This function is providing a metric for the presence of synchronous spiking events across multiple spike trains. + +The complexity is used to characterize synchronous events within the same spike train and across different spike +trains. This way synchronous events can be found both in multi-unit and single-unit spike trains. +Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index, +within and across spike trains. 
+ +Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count. + + + +Expectation and use +------------------- + +A larger value indicates a higher synchrony of the respective spike train with the other spike trains. +Larger values, especially for larger sizes, indicate a higher probability of noisy spikes in spike trains. + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as qm + # Make recording, sorting and wvf_extractor object for your data. + synchrony = qm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8)) + # synchrony is a tuple of dicts with the synchrony metrics for each unit + + +Links to original implementations +--------------------------------- + +The SpikeInterface implementation is a partial port of the low-level complexity functions from `Elephant - Electrophysiology Analysis Toolkit `_ + +References +---------- + +.. automodule:: spikeinterface.toolkit.qualitymetrics.misc_metrics + + .. autofunction:: compute_synchrony_metrics + +Literature +---------- + +Based on concepts described in [Gruen]_ diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 5b4a66244e..7c1a3674b5 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -28,6 +28,7 @@ from .generate import ( generate_recording, generate_sorting, + add_synchrony_to_sorting, create_sorting_npz, generate_snippets, synthesize_random_firings, diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 617a39b6bc..8ab6a2d6b6 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -125,6 +125,31 @@ def generate_sorting( refractory_period_ms=3.0, # in ms seed=None, ): + """ + Generates sorting object with random firings. 
+ + Parameters + ---------- + num_units : int, default: 5 + Number of units + sampling_frequency : float, default: 30000.0 + The sampling frequency + durations : list, default: [10.325, 3.5] + Duration of each segment in s + firing_rates : float, default: 3.0 + The firing rate of each unit (in Hz). + empty_units : list, default: None + List of units that will have no spikes. (used for testing mainly). + refractory_period_ms : float, default: 3.0 + The refractory period in ms + seed : int, default: None + The random seed + + Returns + ------- + sorting : NumpySorting + The sorting object + """ seed = _ensure_seed(seed) num_segments = len(durations) unit_ids = np.arange(num_units) @@ -157,6 +182,59 @@ def generate_sorting( return sorting +def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): + """ + Generates sorting object with added synchronous events from an existing sorting objects. + + Parameters + ---------- + sorting : BaseSorting + The sorting object + sync_event_ratio : float + The ratio of added synchronous spikes with respect to the total number of spikes. + E.g., 0.5 means that the final sorting will have 1.5 times number of spikes, and all the extra + spikes are synchronous (same sample_index), but on different units (not duplicates). 
+ seed : int, default: None + The random seed + + + Returns + ------- + sorting : NumpySorting + The sorting object + + """ + rng = np.random.default_rng(seed) + spikes = sorting.to_spike_vector() + unit_ids = sorting.unit_ids + + # add synchronous events + num_sync = int(len(spikes) * sync_event_ratio) + spikes_duplicated = rng.choice(spikes, size=num_sync, replace=True) + # change unit_index + new_unit_indices = np.zeros(len(spikes_duplicated)) + # make sure labels are all unique, keep unit_indices used for each spike + units_used_for_spike = {} + for i, spike in enumerate(spikes_duplicated): + sample_index = spike["sample_index"] + if sample_index not in units_used_for_spike: + units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) + units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])] + + if len(units_not_used) == 0: + continue + new_unit_indices[i] = rng.choice(units_not_used) + units_used_for_spike[sample_index] = np.append(units_used_for_spike[sample_index], new_unit_indices[i]) + spikes_duplicated["unit_index"] = new_unit_indices + spikes_all = np.concatenate((spikes, spikes_duplicated)) + sort_idxs = np.lexsort([spikes_all["sample_index"], spikes_all["segment_index"]]) + spikes_all = spikes_all[sort_idxs] + + sorting = NumpySorting(spikes=spikes_all, sampling_frequency=sorting.sampling_frequency, unit_ids=unit_ids) + + return sorting + + def create_sorting_npz(num_seg, file_path): # create a NPZ sorting file d = {} diff --git a/src/spikeinterface/postprocessing/tests/test_align_sorting.py b/src/spikeinterface/postprocessing/tests/test_align_sorting.py index 0adda426a9..e5c70ae4b2 100644 --- a/src/spikeinterface/postprocessing/tests/test_align_sorting.py +++ b/src/spikeinterface/postprocessing/tests/test_align_sorting.py @@ -6,7 +6,7 @@ import numpy as np -from spikeinterface import WaveformExtractor, load_extractor, extract_waveforms, NumpySorting +from spikeinterface import NumpySorting from spikeinterface.core 
import generate_sorting from spikeinterface.postprocessing import align_sorting @@ -17,8 +17,8 @@ cache_folder = Path("cache_folder") / "postprocessing" -def test_compute_unit_center_of_mass(): - sorting = generate_sorting(durations=[10.0]) +def test_align_sorting(): + sorting = generate_sorting(durations=[10.0], seed=0) print(sorting) unit_ids = sorting.unit_ids @@ -43,4 +43,4 @@ def test_compute_unit_center_of_mass(): if __name__ == "__main__": - test_compute_unit_center_of_mass() + test_align_sorting() diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index d6648150de..3d562ba5a0 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -38,7 +38,7 @@ def test_compute_correlograms(self): def test_make_bins(): - sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5]) + sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) window_ms = 43.57 bin_ms = 1.6421 @@ -82,14 +82,14 @@ def test_equal_results_correlograms(): if HAVE_NUMBA: methods.append("numba") - sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5]) + sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) _test_correlograms(sorting, window_ms=60.0, bin_ms=2.0, methods=methods) _test_correlograms(sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) def test_flat_cross_correlogram(): - sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0]) + sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0], seed=0) methods = ["numpy"] if HAVE_NUMBA: diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 90b39aee8a..7d43982853 100644 --- 
a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -293,7 +293,7 @@ def __init__( means = means[None, :] stds = np.std(random_data, axis=0) stds = stds[None, :] - gain = 1 / stds + gain = 1.0 / stds offset = -means / stds if int_scale is not None: diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index b62a73a8cb..764acc9852 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -78,13 +78,18 @@ def test_zscore(): assert np.all(np.abs(np.mean(tr, axis=0)) < 0.01) assert np.all(np.abs(np.std(tr, axis=0) - 1) < 0.01) + +def test_zscore_int(): + seed = 1 + rec = generate_recording(seed=seed, mode="legacy") rec_int = scale(rec, dtype="int16", gain=100) with pytest.raises(AssertionError): - rec4 = zscore(rec_int, dtype=None) - rec4 = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=seed) - tr = rec4.get_traces(segment_index=0) - trace_mean = np.mean(tr, axis=0) - trace_std = np.std(tr, axis=0) + zscore(rec_int, dtype=None) + + zscore_recording = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=seed) + traces = zscore_recording.get_traces(segment_index=0) + trace_mean = np.mean(traces, axis=0) + trace_std = np.std(traces, axis=0) assert np.all(np.abs(trace_mean) < 1) assert np.all(np.abs(trace_std - 256) < 1) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 4145b4229b..ee28485983 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -499,6 +499,73 @@ def compute_sliding_rp_violations( ) +def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs): + """ + Compute synchrony metrics. 
Synchrony metrics represent the rate of occurrences of + "synchrony_size" spikes at the exact same sample index. + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + synchrony_sizes : list or tuple, default: (2, 4, 8) + The synchrony sizes to compute. + + Returns + ------- + sync_spike_{X} : dict + The synchrony metric for synchrony size X. + Returns are as many as synchrony_sizes. + + References + ---------- + Based on concepts described in [Gruen]_ + This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ + """ + assert np.all(s > 1 for s in synchrony_sizes), "Synchrony sizes must be greater than 1" + spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit() + sorting = waveform_extractor.sorting + spikes = sorting.to_spike_vector(concatenated=False) + + # Pre-allocate synchrony counts + synchrony_counts = {} + for synchrony_size in synchrony_sizes: + synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64) + + for segment_index in range(sorting.get_num_segments()): + spikes_in_segment = spikes[segment_index] + + # we compute just by counting the occurrence of each sample_index + unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True) + + # add counts for this segment + for unit_index in np.arange(len(sorting.unit_ids)): + spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index] + # some segments/units might have no spikes + if len(spikes_per_unit) == 0: + continue + spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] + for synchrony_size in synchrony_sizes: + synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) + + # add counts for this segment + synchrony_metrics_dict = { + f"sync_spike_{synchrony_size}": { + unit_id: synchrony_counts[synchrony_size][unit_index] / spike_counts[unit_id] + 
for unit_index, unit_id in enumerate(sorting.unit_ids) + } + for synchrony_size in synchrony_sizes + } + + # Convert dict to named tuple + synchrony_metrics_tuple = namedtuple("synchrony_metrics", synchrony_metrics_dict.keys()) + synchrony_metrics = synchrony_metrics_tuple(**synchrony_metrics_dict) + return synchrony_metrics + + +_default_params["synchrony_metrics"] = dict(synchrony_sizes=(0, 2, 4)) + + def compute_amplitude_cutoffs( waveform_extractor, peak_sign="neg", diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 185da589fc..90dbb47a3a 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -11,6 +11,7 @@ compute_amplitude_cutoffs, compute_amplitude_medians, compute_drift_metrics, + compute_synchrony_metrics, ) from .pca_metrics import ( @@ -39,5 +40,6 @@ "sliding_rp_violation": compute_sliding_rp_violations, "amplitude_cutoff": compute_amplitude_cutoffs, "amplitude_median": compute_amplitude_medians, + "synchrony": compute_synchrony_metrics, "drift": compute_drift_metrics, } diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 99ca10ba8f..d927d64c4f 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -2,8 +2,8 @@ import shutil from pathlib import Path import numpy as np -from spikeinterface import extract_waveforms, load_waveforms -from spikeinterface.core import NumpySorting, synthetize_spike_train_bad_isi +from spikeinterface import extract_waveforms +from spikeinterface.core import NumpySorting, synthetize_spike_train_bad_isi, add_synchrony_to_sorting from spikeinterface.extractors.toy_example import toy_example from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions 
@@ -30,6 +30,7 @@ compute_sliding_rp_violations, compute_drift_metrics, compute_amplitude_medians, + compute_synchrony_metrics, ) @@ -65,30 +66,70 @@ def _simulated_data(): return {"duration": max_time, "times": spike_times, "labels": spike_clusters} -def setup_module(): - for folder_name in ("toy_rec", "toy_sorting", "toy_waveforms"): - if (cache_folder / folder_name).is_dir(): - shutil.rmtree(cache_folder / folder_name) +def _waveform_extractor_simple(): + recording, sorting = toy_example(duration=50, seed=10) + recording = recording.save(folder=cache_folder / "rec1") + sorting = sorting.save(folder=cache_folder / "sort1") + folder = cache_folder / "waveform_folder1" + we = extract_waveforms( + recording, + sorting, + folder, + ms_before=3.0, + ms_after=4.0, + max_spikes_per_unit=1000, + n_jobs=1, + chunk_size=30000, + overwrite=True, + ) + _ = compute_principal_components(we, n_components=5, mode="by_channel_local") + return we - recording, sorting = toy_example(num_segments=2, num_units=10) - recording = recording.save(folder=cache_folder / "toy_rec") - sorting = sorting.save(folder=cache_folder / "toy_sorting") +def _waveform_extractor_violations(data): + recording, sorting = toy_example( + duration=[data["duration"]], + spike_times=[data["times"]], + spike_labels=[data["labels"]], + num_segments=1, + num_units=4, + # score_detection=score_detection, + seed=10, + ) + recording = recording.save(folder=cache_folder / "rec2") + sorting = sorting.save(folder=cache_folder / "sort2") + folder = cache_folder / "waveform_folder2" we = extract_waveforms( recording, sorting, - cache_folder / "toy_waveforms", + folder, ms_before=3.0, ms_after=4.0, - max_spikes_per_unit=500, + max_spikes_per_unit=1000, n_jobs=1, chunk_size=30000, + overwrite=True, ) - pca = compute_principal_components(we, n_components=5, mode="by_channel_local") + return we + + +@pytest.fixture(scope="module") +def simulated_data(): + return _simulated_data() + + +@pytest.fixture(scope="module") +def 
waveform_extractor_violations(simulated_data): + return _waveform_extractor_violations(simulated_data) + +@pytest.fixture(scope="module") +def waveform_extractor_simple(): + return _waveform_extractor_simple() -def test_calculate_pc_metrics(): - we = load_waveforms(cache_folder / "toy_waveforms") + +def test_calculate_pc_metrics(waveform_extractor_simple): + we = waveform_extractor_simple print(we) pca = we.load_extension("principal_components") print(pca) @@ -159,39 +200,8 @@ def test_simplified_silhouette_score_metrics(): assert sim_sil_score1 < sim_sil_score2 -@pytest.fixture -def simulated_data(): - return _simulated_data() - - -def setup_dataset(spike_data, score_detection=1): - # def setup_dataset(spike_data): - recording, sorting = toy_example( - duration=[spike_data["duration"]], - spike_times=[spike_data["times"]], - spike_labels=[spike_data["labels"]], - num_segments=1, - num_units=4, - # score_detection=score_detection, - seed=10, - ) - folder = cache_folder / "waveform_folder2" - we = extract_waveforms( - recording, - sorting, - folder, - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=1000, - n_jobs=1, - chunk_size=30000, - overwrite=True, - ) - return we - - -def test_calculate_firing_rate_num_spikes(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_firing_rate_num_spikes(waveform_extractor_simple): + we = waveform_extractor_simple firing_rates = compute_firing_rates(we) num_spikes = compute_num_spikes(we) @@ -202,8 +212,8 @@ def test_calculate_firing_rate_num_spikes(simulated_data): # np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) -def test_calculate_amplitude_cutoff(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_amplitude_cutoff(waveform_extractor_simple): + we = waveform_extractor_simple spike_amps = compute_spike_amplitudes(we) amp_cuts = compute_amplitude_cutoffs(we, num_histogram_bins=10) print(amp_cuts) @@ -213,19 +223,19 @@ def 
test_calculate_amplitude_cutoff(simulated_data): # assert np.allclose(list(amp_cuts_gt.values()), list(amp_cuts.values()), rtol=0.05) -def test_calculate_amplitude_median(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_amplitude_median(waveform_extractor_simple): + we = waveform_extractor_simple spike_amps = compute_spike_amplitudes(we) amp_medians = compute_amplitude_medians(we) - print(amp_medians) + print(spike_amps, amp_medians) # testing method accuracy with magic number is not a good pratcice, I remove this. # amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725} # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) -def test_calculate_snrs(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_snrs(waveform_extractor_simple): + we = waveform_extractor_simple snrs = compute_snrs(we) print(snrs) @@ -234,8 +244,8 @@ def test_calculate_snrs(simulated_data): # assert np.allclose(list(snrs_gt.values()), list(snrs.values()), rtol=0.05) -def test_calculate_presence_ratio(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_presence_ratio(waveform_extractor_simple): + we = waveform_extractor_simple ratios = compute_presence_ratios(we, bin_duration_s=10) print(ratios) @@ -244,8 +254,8 @@ def test_calculate_presence_ratio(simulated_data): # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values())) -def test_calculate_isi_violations(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_isi_violations(waveform_extractor_violations): + we = waveform_extractor_violations isi_viol, counts = compute_isi_violations(we, isi_threshold_ms=1, min_isi_ms=0.0) print(isi_viol) @@ -256,8 +266,8 @@ def test_calculate_isi_violations(simulated_data): # np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) -def test_calculate_sliding_rp_violations(simulated_data): - we = 
setup_dataset(simulated_data) +def test_calculate_sliding_rp_violations(waveform_extractor_violations): + we = waveform_extractor_violations contaminations = compute_sliding_rp_violations(we, bin_size_ms=0.25, window_size_s=1) print(contaminations) @@ -266,13 +276,13 @@ # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05) -def test_calculate_rp_violations(simulated_data): - counts_gt = {0: 2, 1: 4, 2: 10} - we = setup_dataset(simulated_data) +def test_calculate_rp_violations(waveform_extractor_violations): + we = waveform_extractor_violations rp_contamination, counts = compute_refrac_period_violations(we, refractory_period_ms=1, censored_period_ms=0.0) - print(rp_contamination) + print(rp_contamination, counts) # testing method accuracy with magic number is not a good practice, I remove this. + # counts_gt = {0: 2, 1: 4, 2: 10} # rp_contamination_gt = {0: 0.10534956502609294, 1: 1.0, 2: 1.0} # assert np.allclose(list(rp_contamination_gt.values()), list(rp_contamination.values()), rtol=0.05) # np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) @@ -286,9 +296,44 @@ assert np.isnan(rp_contamination[1]) +def test_synchrony_metrics(waveform_extractor_simple): + we = waveform_extractor_simple + sorting = we.sorting + synchrony_sizes = (2, 3, 4) + synchrony_metrics = compute_synchrony_metrics(we, synchrony_sizes=synchrony_sizes) + print(synchrony_metrics) + + # check returns + for size in synchrony_sizes: + assert f"sync_spike_{size}" in synchrony_metrics._fields + + # here we test that increasing added synchrony is captured by synchrony metrics + added_synchrony_levels = (0.2, 0.5, 0.8) + previous_waveform_extractor = we + for sync_level in added_synchrony_levels: + sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level) + waveform_extractor_sync = 
extract_waveforms(previous_waveform_extractor.recording, sorting_sync, mode="memory") + previous_synchrony_metrics = compute_synchrony_metrics( + previous_waveform_extractor, synchrony_sizes=synchrony_sizes + ) + current_synchrony_metrics = compute_synchrony_metrics(waveform_extractor_sync, synchrony_sizes=synchrony_sizes) + print(current_synchrony_metrics) + # check that all values increased + for i, col in enumerate(previous_synchrony_metrics._fields): + assert np.all( + v_prev < v_curr + for (v_prev, v_curr) in zip( + previous_synchrony_metrics[i].values(), current_synchrony_metrics[i].values() + ) + ) + + # set new previous waveform extractor + previous_waveform_extractor = waveform_extractor_sync + + @pytest.mark.sortingcomponents -def test_calculate_drift_metrics(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_drift_metrics(waveform_extractor_simple): + we = waveform_extractor_simple spike_locs = compute_spike_locations(we) drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics(we, interval_s=10, min_spikes_per_interval=10) @@ -304,11 +349,13 @@ def test_calculate_drift_metrics(simulated_data): if __name__ == "__main__": - setup_module() sim_data = _simulated_data() - test_calculate_amplitude_cutoff(sim_data) - test_calculate_presence_ratio(sim_data) - test_calculate_amplitude_median(sim_data) - test_calculate_isi_violations(sim_data) - test_calculate_sliding_rp_violations(sim_data) - test_calculate_drift_metrics(sim_data) + we = _waveform_extractor_simple() + we_violations = _waveform_extractor_violations(sim_data) + # test_calculate_amplitude_cutoff(we) + # test_calculate_presence_ratio(we) + # test_calculate_amplitude_median(we) + # test_calculate_isi_violations(we) + # test_calculate_sliding_rp_violations(we) + # test_calculate_drift_metrics(we) + test_synchrony_metrics(we)