From 546831b7258492a2ccbe1db54fc57f6ed19e3726 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 31 Aug 2023 19:59:40 +0200 Subject: [PATCH 1/8] Implement synchrony metrics without elephant --- doc/modules/qualitymetrics/synchrony.rst | 49 +++++++++++++++++++ .../qualitymetrics/misc_metrics.py | 46 +++++++++++++++++ .../qualitymetrics/quality_metric_list.py | 2 + 3 files changed, 97 insertions(+) create mode 100644 doc/modules/qualitymetrics/synchrony.rst diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst new file mode 100644 index 0000000000..d826138ad6 --- /dev/null +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -0,0 +1,49 @@ +Synchrony Metrics (:code:`synchrony`) +===================================== + +Calculation +----------- +This function provides a metric for the presence of synchronous spiking events across multiple spike trains. + +The complexity is used to characterize synchronous events within the same spike train and across different spike +trains. This way synchronous events can be found both in multi-unit and single-unit spike trains. + +Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur separated by spread - 1 or less empty bins, +within and across spike trains in the spiketrains list. + +Expectation and use +------------------- +A larger value indicates a higher synchrony of the respective spike train with the other spike trains. + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as qm + # Make recording, sorting and wvf_extractor object for your data. + presence_ratio = qm.compute_synchrony_metrics(wvf_extractor) + # presence_ratio is a tuple of dicts with the synchrony metrics for each unit + +Links to source code +-------------------- + +From `Elephant - Electrophysiology Analysis Toolkit `_ + + +References +---------- + +.. automodule:: spikeinterface.toolkit.qualitymetrics.misc_metrics + + .. autofunction:: compute_synchrony_metrics + +Literature +---------- + +Described in Gruen_ + +Citations +--------- +.. [Gruen] Sonja Grün, Moshe Abeles, and Markus Diesmann. Impact of higher-order correlations on coincidence distributions of massively parallel data. +In International School on Neural Networks, Initiated by IIASS and EMFCSC, volume 5286, 96–114. Springer, 2007.
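To make the complexity-based calculation above concrete, here is a minimal, self-contained sketch of the counting the metric performs, with the same per-unit normalization used by the implementation further down in this patch. It is illustrative only, not the SpikeInterface implementation: the :code:`spike_samples_by_unit` dictionary and all variable names are hypothetical.

.. code-block:: python

    import numpy as np

    # hypothetical input: spike sample indices per unit (illustrative values)
    spike_samples_by_unit = {
        0: np.array([10, 20, 30, 40]),
        1: np.array([10, 25, 30]),
        2: np.array([10, 30, 50]),
    }
    synchrony_sizes = (2, 4, 8)

    # complexity = number of spikes, pooled over all units, at each occupied sample index
    all_samples = np.concatenate(list(spike_samples_by_unit.values()))
    unique_samples, complexity = np.unique(all_samples, return_counts=True)

    sync_metrics = {size: {} for size in synchrony_sizes}
    for unit_id, samples in spike_samples_by_unit.items():
        # complexity of the event each of this unit's spikes participates in
        spike_complexity = complexity[np.searchsorted(unique_samples, samples)]
        for size in synchrony_sizes:
            # fraction of this unit's spikes belonging to an event with >= size coincident spikes
            sync_metrics[size][unit_id] = np.count_nonzero(spike_complexity >= size) / len(samples)

    print(sync_metrics[2])  # {0: 0.5, 1: 0.666..., 2: 0.666...}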
diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 778de8aea4..158854e195 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -498,6 +498,52 @@ def compute_sliding_rp_violations( ) +def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs): + spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit() + sorting = waveform_extractor.sorting + spikes = sorting.to_spike_vector(concatenated=False) + + # Pre-allocate synchrony counts + synchrony_counts = {} + for synchrony_size in synchrony_sizes: + synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64) + + for segment_index in range(sorting.get_num_segments()): + num_samples = waveform_extractor.get_num_samples(segment_index) + spikes_in_segment = spikes[segment_index] + + # we compute the complexity as an histogram with a single sample as bin + bins = np.arange(0, num_samples + 1) + complexity = np.histogram(spikes_in_segment["sample_index"], bins)[0] + + # add counts for this segment + for unit_index in np.arange(len(sorting.unit_ids)): + spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index] + # some segments/units might have no spikes + if len(spikes_per_unit) == 0: + continue + spike_complexity = complexity[spikes_per_unit["sample_index"]] + for synchrony_size in synchrony_sizes: + synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) + + # build the synchrony metrics dict (per-unit rates for each synchrony size) + synchrony_metrics_dict = { + f"sync_spike_{synchrony_size}": { + unit_id: synchrony_counts[synchrony_size][unit_index] / spike_counts[unit_id] + for unit_index, unit_id in enumerate(sorting.unit_ids) + } + for synchrony_size in synchrony_sizes + } + + # Convert dict to named tuple + synchrony_metrics_tuple = namedtuple("synchrony_metrics", synchrony_metrics_dict.keys()) + synchrony_metrics = synchrony_metrics_tuple(**synchrony_metrics_dict) + return synchrony_metrics + + +_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) + + def compute_amplitude_cutoffs( waveform_extractor, peak_sign="neg", diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 185da589fc..90dbb47a3a 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -11,6 +11,7 @@ compute_amplitude_cutoffs, compute_amplitude_medians, compute_drift_metrics, + compute_synchrony_metrics, ) from .pca_metrics import ( @@ -39,5 +40,6 @@ "sliding_rp_violation": compute_sliding_rp_violations, "amplitude_cutoff": compute_amplitude_cutoffs, "amplitude_median": compute_amplitude_medians, + "synchrony": compute_synchrony_metrics, "drift": compute_drift_metrics, } From db44a10532db5f65c4b2caacf1a8cffe9f10de5a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 1 Sep 2023 10:27:09 +0200 Subject: [PATCH 2/8] Update docstring, doc, and references --- doc/modules/qualitymetrics/references.rst | 2 ++ doc/modules/qualitymetrics/synchrony.rst | 26 +++++++++---------- .../qualitymetrics/misc_metrics.py | 23 ++++++++++++++++ 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/doc/modules/qualitymetrics/references.rst b/doc/modules/qualitymetrics/references.rst index 8dd8a21548..4f10c7b2b7 100644 --- a/doc/modules/qualitymetrics/references.rst +++
b/doc/modules/qualitymetrics/references.rst @@ -11,6 +11,8 @@ References .. [Hruschka] Hruschka, E.R., de Castro, L.N., Campello R.J.G.B. "Evolutionary algorithms for clustering gene-expression data." Fourth IEEE International Conference on Data Mining (ICDM'04) 2004, pp 403-406. +.. [Gruen] Sonja Grün, Moshe Abeles, and Markus Diesmann. Impact of higher-order correlations on coincidence distributions of massively parallel data. In International School on Neural Networks, Initiated by IIASS and EMFCSC, volume 5286, 96–114. Springer, 2007. + .. [IBL] International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. .. [Jackson] Jadin Jackson, Neil Schmitzer-Torbert, K.D. Harris, and A.D. Redish. Quantitative assessment of extracellular multichannel recording quality using measures of cluster separation. Soc Neurosci Abstr, 518, 01 2005. diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index d826138ad6..71f4579e30 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -7,13 +7,18 @@ This function is providing a metric for the presence of synchronous spiking even The complexity is used to characterize synchronous events within the same spike train and across different spike trains. This way synchronous events can be found both in multi-unit and single-unit spike trains. +Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index, +within and across spike trains. + +Synchrony metrics can be computed for different syncrony sizes (>1), defining the number of simultanous spikes to count. + -Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur separated by spread - 1 or less empty bins, -within and across spike trains in the spiketrains list. Expectation and use ------------------- + A larger value indicates a higher synchrony of the respective spike train with the other spike trains. +Higher values, especially for high sizes, indicate a higher probability of noisy spikes in spike trains. Example code ------------ @@ -22,14 +27,14 @@ Example code import spikeinterface.qualitymetrics as qm # Make recording, sorting and wvf_extractor object for your data. - presence_ratio = qm.compute_synchrony_metrics(wvf_extractor) - # presence_ratio is a tuple of dicts with the synchrony metrics for each unit + synchrony = qm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8)) + # synchrony is a tuple of dicts with the synchrony metrics for each unit -Links to source code --------------------- -From `Elephant - Electrophysiology Analysis Toolkit `_ +Links to original implementations +--------------------------------- +The SpikeInterface implementation is a partial port of the low-level complexity functions from `Elephant - Electrophysiology Analysis Toolkit `_ References ---------- @@ -41,9 +46,4 @@ References Literature ---------- -Described in Gruen_ - -Citations ---------- -.. [Gruen] Sonja Grün, Moshe Abeles, and Markus Diesmann. Impact of higher-order correlations on coincidence distributions of massively parallel data. -In International School on Neural Networks, Initiated by IIASS and EMFCSC, volume 5286, 96–114. Springer, 2007. 
+Based on concepts described in Gruen_ diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 158854e195..066e202e39 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -499,6 +499,29 @@ def compute_sliding_rp_violations( def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs): + """ + Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of + "synchrony_size" spikes at the exact same sample index. + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + synchrony_sizes : list or tuple, default: (2, 4, 8) + The synchrony sizes to compute. + + Returns + ------- + sync_spike_{X} : dict + The synchrony metric for synchrony size X. + One metric is returned for each entry in synchrony_sizes. + + References + ---------- + Based on concepts described in [Gruen]_ + This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ + """ + assert all(s > 1 for s in synchrony_sizes), "Synchrony sizes must be greater than 1" spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit() sorting = waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) From 87c6386c90ffe9840971bbfc00ebdd56f5052955 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 1 Sep 2023 16:23:30 +0200 Subject: [PATCH 3/8] Add synchrony function, add and optimize quality metrics tests --- src/spikeinterface/core/__init__.py | 1 + src/spikeinterface/core/generate.py | 78 +++++++ .../qualitymetrics/misc_metrics.py | 8 +- .../tests/test_metrics_functions.py | 193 +++++++++++------- 4 files changed, 202 insertions(+), 78 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 5b4a66244e..7c1a3674b5 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -28,6 +28,7 @@ from .generate import ( generate_recording, generate_sorting, + add_synchrony_to_sorting, create_sorting_npz, generate_snippets, synthesize_random_firings, diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 93b9459b5f..0f318d2b3d 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -120,6 +120,31 @@ def generate_sorting( refractory_period_ms=3.0, # in ms seed=None, ): + """ + Generates a sorting object with random firings. + + Parameters + ---------- + num_units : int, default: 5 + Number of units + sampling_frequency : float, default: 30000.0 + The sampling frequency + durations : list, default: [10.325, 3.5] + Duration of each segment in s + firing_rates : float, default: 3.0 + The firing rate of each unit (in Hz). + empty_units : list, default: None + List of units to remove from the sorting + refractory_period_ms : float, default: 3.0 + The refractory period in ms + seed : int, default: None + The random seed + + Returns + ------- + sorting : NumpySorting + The sorting object + """ seed = _ensure_seed(seed) num_segments = len(durations) unit_ids = np.arange(num_units) @@ -152,6 +177,59 @@ def generate_sorting( return sorting +def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): + """ + Generates a sorting object with added synchronous events from an existing sorting object.
+ + Parameters + ---------- + sorting : BaseSorting + The sorting object + sync_event_ratio : float, default: 0.3 + The ratio of added synchronous spikes with respect to the total number of spikes. + E.g., 0.5 means that the final sorting will have 1.5 times the number of spikes, and all the extra + spikes are synchronous (same sample_index), but on different units (not duplicates). + seed : int, default: None + The random seed + + + Returns + ------- + sorting : NumpySorting + The sorting object + + """ + rng = np.random.default_rng(seed) + spikes = sorting.to_spike_vector() + unit_ids = sorting.unit_ids + + # add synchronous events + num_sync = int(len(spikes) * sync_event_ratio) + spikes_duplicated = rng.choice(spikes, size=num_sync, replace=True) + # change unit_index + new_unit_indices = np.zeros(len(spikes_duplicated)) + # make sure labels are all unique, keep unit_indices used for each spike + units_used_for_spike = {} + for i, spike in enumerate(spikes_duplicated): + sample_index = spike["sample_index"] + if sample_index not in units_used_for_spike: + units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) + units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])] + + if len(units_not_used) == 0: + continue + new_unit_indices[i] = rng.choice(units_not_used) + units_used_for_spike[sample_index] = np.append(units_used_for_spike[sample_index], new_unit_indices[i]) + spikes_duplicated["unit_index"] = new_unit_indices + spikes_all = np.concatenate((spikes, spikes_duplicated)) + sort_idxs = np.lexsort([spikes_all["sample_index"], spikes_all["segment_index"]]) + spikes_all = spikes_all[sort_idxs] + + sorting = NumpySorting(spikes=spikes_all, sampling_frequency=sorting.sampling_frequency, unit_ids=unit_ids) + + return sorting + + def create_sorting_npz(num_seg, file_path): # create a NPZ sorting file d = {} diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 066e202e39..83f6ecc244 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -532,12 +532,10 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64) for segment_index in range(sorting.get_num_segments()): - num_samples = waveform_extractor.get_num_samples(segment_index) spikes_in_segment = spikes[segment_index] - # we compute the complexity as an histogram with a single sample as bin - bins = np.arange(0, num_samples + 1) - complexity = np.histogram(spikes_in_segment["sample_index"], bins)[0] + # we compute the complexity by counting the occurrences of each sample_index + unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True) # add counts for this segment for unit_index in np.arange(len(sorting.unit_ids)): @@ -545,7 +543,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index] # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue - spike_complexity = complexity[spikes_per_unit["sample_index"]] + spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index
99ca10ba8f..d927d64c4f 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -2,8 +2,8 @@ import shutil from pathlib import Path import numpy as np -from spikeinterface import extract_waveforms, load_waveforms -from spikeinterface.core import NumpySorting, synthetize_spike_train_bad_isi +from spikeinterface import extract_waveforms +from spikeinterface.core import NumpySorting, synthetize_spike_train_bad_isi, add_synchrony_to_sorting from spikeinterface.extractors.toy_example import toy_example from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions @@ -30,6 +30,7 @@ compute_sliding_rp_violations, compute_drift_metrics, compute_amplitude_medians, + compute_synchrony_metrics, ) @@ -65,30 +66,70 @@ def _simulated_data(): return {"duration": max_time, "times": spike_times, "labels": spike_clusters} -def setup_module(): - for folder_name in ("toy_rec", "toy_sorting", "toy_waveforms"): - if (cache_folder / folder_name).is_dir(): - shutil.rmtree(cache_folder / folder_name) +def _waveform_extractor_simple(): + recording, sorting = toy_example(duration=50, seed=10) + recording = recording.save(folder=cache_folder / "rec1") + sorting = sorting.save(folder=cache_folder / "sort1") + folder = cache_folder / "waveform_folder1" + we = extract_waveforms( + recording, + sorting, + folder, + ms_before=3.0, + ms_after=4.0, + max_spikes_per_unit=1000, + n_jobs=1, + chunk_size=30000, + overwrite=True, + ) + _ = compute_principal_components(we, n_components=5, mode="by_channel_local") + return we - recording, sorting = toy_example(num_segments=2, num_units=10) - recording = recording.save(folder=cache_folder / "toy_rec") - sorting = sorting.save(folder=cache_folder / "toy_sorting") +def _waveform_extractor_violations(data): + recording, sorting = toy_example( + duration=[data["duration"]], + spike_times=[data["times"]], + spike_labels=[data["labels"]], + num_segments=1, + num_units=4, + # score_detection=score_detection, + seed=10, + ) + recording = recording.save(folder=cache_folder / "rec2") + sorting = sorting.save(folder=cache_folder / "sort2") + folder = cache_folder / "waveform_folder2" we = extract_waveforms( recording, sorting, - cache_folder / "toy_waveforms", + folder, ms_before=3.0, ms_after=4.0, - max_spikes_per_unit=500, + max_spikes_per_unit=1000, n_jobs=1, chunk_size=30000, + overwrite=True, ) - pca = compute_principal_components(we, n_components=5, mode="by_channel_local") + return we + + +@pytest.fixture(scope="module") +def simulated_data(): + return _simulated_data() + + +@pytest.fixture(scope="module") +def waveform_extractor_violations(simulated_data): + return _waveform_extractor_violations(simulated_data) + +@pytest.fixture(scope="module") +def waveform_extractor_simple(): + return _waveform_extractor_simple() -def test_calculate_pc_metrics(): - we = load_waveforms(cache_folder / "toy_waveforms") + +def test_calculate_pc_metrics(waveform_extractor_simple): + we = waveform_extractor_simple print(we) pca = we.load_extension("principal_components") print(pca) @@ -159,39 +200,8 @@ def test_simplified_silhouette_score_metrics(): assert sim_sil_score1 < sim_sil_score2 -@pytest.fixture -def simulated_data(): - return _simulated_data() - - -def setup_dataset(spike_data, score_detection=1): - # def setup_dataset(spike_data): - recording, sorting = toy_example( - duration=[spike_data["duration"]], - spike_times=[spike_data["times"]], - 
spike_labels=[spike_data["labels"]], - num_segments=1, - num_units=4, - # score_detection=score_detection, - seed=10, - ) - folder = cache_folder / "waveform_folder2" - we = extract_waveforms( - recording, - sorting, - folder, - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=1000, - n_jobs=1, - chunk_size=30000, - overwrite=True, - ) - return we - - -def test_calculate_firing_rate_num_spikes(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_firing_rate_num_spikes(waveform_extractor_simple): + we = waveform_extractor_simple firing_rates = compute_firing_rates(we) num_spikes = compute_num_spikes(we) @@ -202,8 +212,8 @@ def test_calculate_firing_rate_num_spikes(simulated_data): # np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) -def test_calculate_amplitude_cutoff(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_amplitude_cutoff(waveform_extractor_simple): + we = waveform_extractor_simple spike_amps = compute_spike_amplitudes(we) amp_cuts = compute_amplitude_cutoffs(we, num_histogram_bins=10) print(amp_cuts) @@ -213,19 +223,19 @@ def test_calculate_amplitude_cutoff(simulated_data): # assert np.allclose(list(amp_cuts_gt.values()), list(amp_cuts.values()), rtol=0.05) -def test_calculate_amplitude_median(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_amplitude_median(waveform_extractor_simple): + we = waveform_extractor_simple spike_amps = compute_spike_amplitudes(we) amp_medians = compute_amplitude_medians(we) - print(amp_medians) + print(spike_amps, amp_medians) # testing method accuracy with magic number is not a good pratcice, I remove this. # amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725} # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) -def test_calculate_snrs(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_snrs(waveform_extractor_simple): + we = waveform_extractor_simple snrs = compute_snrs(we) print(snrs) @@ -234,8 +244,8 @@ def test_calculate_snrs(simulated_data): # assert np.allclose(list(snrs_gt.values()), list(snrs.values()), rtol=0.05) -def test_calculate_presence_ratio(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_presence_ratio(waveform_extractor_simple): + we = waveform_extractor_simple ratios = compute_presence_ratios(we, bin_duration_s=10) print(ratios) @@ -244,8 +254,8 @@ def test_calculate_presence_ratio(simulated_data): # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values())) -def test_calculate_isi_violations(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_isi_violations(waveform_extractor_violations): + we = waveform_extractor_violations isi_viol, counts = compute_isi_violations(we, isi_threshold_ms=1, min_isi_ms=0.0) print(isi_viol) @@ -256,8 +266,8 @@ def test_calculate_isi_violations(simulated_data): # np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) -def test_calculate_sliding_rp_violations(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_sliding_rp_violations(waveform_extractor_violations): + we = waveform_extractor_violations contaminations = compute_sliding_rp_violations(we, bin_size_ms=0.25, window_size_s=1) print(contaminations) @@ -266,13 +276,13 @@ def test_calculate_sliding_rp_violations(simulated_data): # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05) -def 
test_calculate_rp_violations(simulated_data): - counts_gt = {0: 2, 1: 4, 2: 10} - we = setup_dataset(simulated_data) +def test_calculate_rp_violations(waveform_extractor_violations): + we = waveform_extractor_violations rp_contamination, counts = compute_refrac_period_violations(we, refractory_period_ms=1, censored_period_ms=0.0) - print(rp_contamination) + print(rp_contamination, counts) # testing method accuracy with magic number is not a good pratcice, I remove this. + # counts_gt = {0: 2, 1: 4, 2: 10} # rp_contamination_gt = {0: 0.10534956502609294, 1: 1.0, 2: 1.0} # assert np.allclose(list(rp_contamination_gt.values()), list(rp_contamination.values()), rtol=0.05) # np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) @@ -286,9 +296,44 @@ def test_calculate_rp_violations(simulated_data): assert np.isnan(rp_contamination[1]) +def test_synchrony_metrics(waveform_extractor_simple): + we = waveform_extractor_simple + sorting = we.sorting + synchrony_sizes = (2, 3, 4) + synchrony_metrics = compute_synchrony_metrics(we, synchrony_sizes=synchrony_sizes) + print(synchrony_metrics) + + # check returns + for size in synchrony_sizes: + assert f"sync_spike_{size}" in synchrony_metrics._fields + + # here we test that increasing added synchrony is captured by synchrony metrics + added_synchrony_levels = (0.2, 0.5, 0.8) + previous_waveform_extractor = we + for sync_level in added_synchrony_levels: + sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level) + waveform_extractor_sync = extract_waveforms(previous_waveform_extractor.recording, sorting_sync, mode="memory") + previous_synchrony_metrics = compute_synchrony_metrics( + previous_waveform_extractor, synchrony_sizes=synchrony_sizes + ) + current_synchrony_metrics = compute_synchrony_metrics(waveform_extractor_sync, synchrony_sizes=synchrony_sizes) + print(current_synchrony_metrics) + # check that the metric values do not decrease when adding synchrony + for i, col in enumerate(previous_synchrony_metrics._fields): + assert all( + v_prev <= v_curr + for (v_prev, v_curr) in zip( + previous_synchrony_metrics[i].values(), current_synchrony_metrics[i].values() + ) + ) + + # set new previous waveform extractor + previous_waveform_extractor = waveform_extractor_sync + + @pytest.mark.sortingcomponents -def test_calculate_drift_metrics(simulated_data): - we = setup_dataset(simulated_data) +def test_calculate_drift_metrics(waveform_extractor_simple): + we = waveform_extractor_simple spike_locs = compute_spike_locations(we) drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics(we, interval_s=10, min_spikes_per_interval=10) @@ -304,11 +349,13 @@ def test_calculate_drift_metrics(simulated_data): if __name__ == "__main__": - setup_module() sim_data = _simulated_data() - test_calculate_amplitude_cutoff(sim_data) - test_calculate_presence_ratio(sim_data) - test_calculate_amplitude_median(sim_data) - test_calculate_isi_violations(sim_data) - test_calculate_sliding_rp_violations(sim_data) - test_calculate_drift_metrics(sim_data) + we = _waveform_extractor_simple() + we_violations = _waveform_extractor_violations(sim_data) + # test_calculate_amplitude_cutoff(we) + # test_calculate_presence_ratio(we) + # test_calculate_amplitude_median(we) + # test_calculate_isi_violations(we) + # test_calculate_sliding_rp_violations(we) + # test_calculate_drift_metrics(we) + test_synchrony_metrics(we) From 95bd819a2f46cca9396c3e9e614dc7b428149d1d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 1 Sep 2023 16:25:56 +0200 Subject: [PATCH 4/8] Update
doc/modules/qualitymetrics/synchrony.rst Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- doc/modules/qualitymetrics/synchrony.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index 71f4579e30..8769882fa5 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -10,7 +10,7 @@ trains. This way synchronous events can be found both in multi-unit and single-u Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index, within and across spike trains. -Synchrony metrics can be computed for different syncrony sizes (>1), defining the number of simultanous spikes to count. +Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count. From c4c566d001165cd423f18c6699906298a94323ef Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 1 Sep 2023 16:26:03 +0200 Subject: [PATCH 5/8] Update doc/modules/qualitymetrics/synchrony.rst Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- doc/modules/qualitymetrics/synchrony.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index 8769882fa5..b41e194466 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -18,7 +18,7 @@ Expectation and use ------------------- A larger value indicates a higher synchrony of the respective spike train with the other spike trains. -Higher values, especially for high sizes, indicate a higher probability of noisy spikes in spike trains. +Larger values, especially for larger sizes, indicate a higher probability of noisy spikes in spike trains. Example code ------------ From 023778cb5cc8eff62f4b0c86df7f62735f315f67 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Sep 2023 15:03:39 +0200 Subject: [PATCH 6/8] Update src/spikeinterface/core/generate.py Co-authored-by: Garcia Samuel --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0f318d2b3d..bbf77682ee 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -134,7 +134,7 @@ def generate_sorting( firing_rates : float, default: 3.0 The firing rate of each unit (in Hz). empty_units : list, default: None - List of units to remove from the sorting + List of units that will have no spikes. (used for testing mainly). refractory_period_ms : float, default: 3.0 The refractory period in ms seed : int, default: None From ae6099cdc455bfe0d76a085781abc89a62bee780 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 6 Sep 2023 00:28:18 -0400 Subject: [PATCH 7/8] Fix the [full] install for Macs (#1955) * Fix for mac install. 
* update doc comments for dependencies --- pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3ecfbe2718..e17d6f6506 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,7 @@ full = [ "networkx", "distinctipy", "matplotlib", - "cuda-python; sys_platform != 'darwin'", + "cuda-python; platform_system != 'Darwin'", "numba", ] @@ -151,9 +151,9 @@ docs = [ # for notebooks in the gallery "MEArec", # Use as an example "datalad==0.16.2", # Download mearec data, not sure if needed as is installed with conda as well because of git-annex - "pandas", # Don't know where this is needed - "hdbscan>=0.8.33", # For sorters, probably spikingcircus - "numba", # For sorters, probably spikingcircus + "pandas", # in the modules gallery comparison tutorial + "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous + "numba", # For many postprocessing functions # for release we need pypi, so this needs to be commented "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version From 6ed5a09ca6a6e18c4a0eaeaec88638389c1b2c1e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 7 Sep 2023 09:20:51 +0200 Subject: [PATCH 8/8] Fix seeds in postprocessing tests --- .../postprocessing/tests/test_align_sorting.py | 8 ++++---- .../postprocessing/tests/test_correlograms.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_align_sorting.py b/src/spikeinterface/postprocessing/tests/test_align_sorting.py index 0adda426a9..e5c70ae4b2 100644 --- a/src/spikeinterface/postprocessing/tests/test_align_sorting.py +++ b/src/spikeinterface/postprocessing/tests/test_align_sorting.py @@ -6,7 +6,7 @@ import numpy as np -from spikeinterface import WaveformExtractor, load_extractor, extract_waveforms, NumpySorting +from spikeinterface import NumpySorting from spikeinterface.core import generate_sorting from spikeinterface.postprocessing import align_sorting @@ -17,8 +17,8 @@ cache_folder = Path("cache_folder") / "postprocessing" -def test_compute_unit_center_of_mass(): - sorting = generate_sorting(durations=[10.0]) +def test_align_sorting(): + sorting = generate_sorting(durations=[10.0], seed=0) print(sorting) unit_ids = sorting.unit_ids @@ -43,4 +43,4 @@ def test_compute_unit_center_of_mass(): if __name__ == "__main__": - test_compute_unit_center_of_mass() + test_align_sorting() diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index d6648150de..3d562ba5a0 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -38,7 +38,7 @@ def test_compute_correlograms(self): def test_make_bins(): - sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5]) + sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) window_ms = 43.57 bin_ms = 1.6421 @@ -82,14 +82,14 @@ def test_equal_results_correlograms(): if HAVE_NUMBA: methods.append("numba") - sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5]) + sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) _test_correlograms(sorting, window_ms=60.0, 
bin_ms=2.0, methods=methods) _test_correlograms(sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) def test_flat_cross_correlogram(): - sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0]) + sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0], seed=0) methods = ["numpy"] if HAVE_NUMBA:
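Taken together, the pieces added in this patch series can be exercised end to end. The sketch below is a hedged usage example, not part of the patch itself: it assumes the patches above are installed, reuses the toy_example call from the new test, and the seeds and printed values are purely illustrative.

.. code-block:: python

    from spikeinterface import extract_waveforms
    from spikeinterface.core import add_synchrony_to_sorting
    from spikeinterface.extractors import toy_example
    import spikeinterface.qualitymetrics as qm

    # small toy dataset (same call used by test_synchrony_metrics)
    recording, sorting = toy_example(duration=50, seed=10)
    we = extract_waveforms(recording, sorting, mode="memory")

    # baseline synchrony metrics (namedtuple of per-unit dicts)
    baseline = qm.compute_synchrony_metrics(we, synchrony_sizes=(2, 4, 8))

    # inject 30% extra synchronous spikes and recompute
    sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=0)
    we_sync = extract_waveforms(recording, sorting_sync, mode="memory")
    synced = qm.compute_synchrony_metrics(we_sync, synchrony_sizes=(2, 4, 8))

    # per-unit sync_spike_2 rates should not be lower after adding synchronous events
    print(baseline.sync_spike_2)
    print(synced.sync_spike_2)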