From c61378ec1d3fad9b5f9d6521a51a0dd5376dfc35 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 24 Jul 2024 11:45:17 +0100 Subject: [PATCH 1/9] simplify qm tests and wrap for multiprocessing --- .../qualitymetrics/tests/conftest.py | 51 +++++++++++ .../tests/test_metrics_functions.py | 85 +++++++------------ .../qualitymetrics/tests/test_pca_metrics.py | 5 ++ .../tests/test_quality_metric_calculator.py | 53 +----------- 4 files changed, 92 insertions(+), 102 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index bb2a345340..b1b23fcaee 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,6 +5,8 @@ create_sorting_analyzer, ) +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + def _small_sorting_analyzer(): recording, sorting = generate_ground_truth_recording( @@ -35,3 +37,52 @@ def _small_sorting_analyzer(): @pytest.fixture(scope="module") def small_sorting_analyzer(): return _small_sorting_analyzer() + + +def _sorting_analyzer_simple(): + + # we need high firing rate for amplitude_cutoff + recording, sorting = generate_ground_truth_recording( + durations=[ + 120.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params=dict( + alpha=(200.0, 500.0), + ) + ), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=1205, + ) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) + + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def sorting_analyzer_simple(): + sorting_analyzer = get_sorting_analyzer(seed=2205) + return sorting_analyzer + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def sorting_analyzer_simple(): + return _sorting_analyzer_simple() diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index bb222200e9..446007d10b 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -135,36 +135,6 @@ def test_unit_id_order_independence(small_sorting_analyzer): assert quality_metrics_2[metric][1] == metric_1_data["#4"] -def _sorting_analyzer_simple(): - recording, sorting = generate_ground_truth_recording( - durations=[ - 50.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=2205, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() - - def _sorting_violation(): max_time = 100.0 sampling_frequency = 30000 @@ -570,27 +540,36 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): if __name__ == "__main__": - sorting_analyzer = _sorting_analyzer_simple() - print(sorting_analyzer) - - test_unit_structure_in_output(_small_sorting_analyzer()) - - # test_calculate_firing_rate_num_spikes(sorting_analyzer) - # test_calculate_snrs(sorting_analyzer) - # test_calculate_amplitude_cutoff(sorting_analyzer) - # test_calculate_presence_ratio(sorting_analyzer) - # test_calculate_amplitude_median(sorting_analyzer) - # test_calculate_sliding_rp_violations(sorting_analyzer) - # test_calculate_drift_metrics(sorting_analyzer) - # test_synchrony_metrics(sorting_analyzer) - # test_synchrony_metrics_unit_id_subset(sorting_analyzer) - # test_synchrony_metrics_no_unit_ids(sorting_analyzer) - # test_calculate_firing_range(sorting_analyzer) - # test_calculate_amplitude_cv_metrics(sorting_analyzer) - # test_calculate_sd_ratio(sorting_analyzer) - - # sorting_analyzer_violations = _sorting_analyzer_violations() + test_unit_structure_in_output(small_sorting_analyzer) + test_unit_id_order_independence(small_sorting_analyzer) + + test_synchrony_counts_no_sync() + test_synchrony_counts_one_sync() + test_synchrony_counts_one_quad_sync() + test_synchrony_counts_not_all_units() + + test_mahalanobis_metrics() + test_lda_metrics() + test_nearest_neighbors_metrics() + test_silhouette_score_metrics() + test_simplified_silhouette_score_metrics() + + test_calculate_firing_rate_num_spikes(sorting_analyzer_simple) + test_calculate_snrs(sorting_analyzer) + test_calculate_amplitude_cutoff(sorting_analyzer) + test_calculate_presence_ratio(sorting_analyzer) + test_calculate_amplitude_median(sorting_analyzer) + test_calculate_sliding_rp_violations(sorting_analyzer) + test_calculate_drift_metrics(sorting_analyzer) + test_synchrony_metrics(sorting_analyzer) + test_synchrony_metrics_unit_id_subset(sorting_analyzer) + test_synchrony_metrics_no_unit_ids(sorting_analyzer) + test_calculate_firing_range(sorting_analyzer) + test_calculate_amplitude_cv_metrics(sorting_analyzer) + test_calculate_sd_ratio(sorting_analyzer) + + sorting_analyzer_violations = _sorting_analyzer_violations() # print(sorting_analyzer_violations) - # test_calculate_isi_violations(sorting_analyzer_violations) - # test_calculate_sliding_rp_violations(sorting_analyzer_violations) - # test_calculate_rp_violations(sorting_analyzer_violations) + test_calculate_isi_violations(sorting_analyzer_violations) + test_calculate_sliding_rp_violations(sorting_analyzer_violations) + test_calculate_rp_violations(sorting_analyzer_violations) diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 6ddeb02689..a0fc97c37c 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -22,3 +22,8 @@ def test_calculate_pc_metrics(small_sorting_analyzer): assert not np.all(np.isnan(res2[metric_name].values)) assert np.array_equal(res1[metric_name].values, res2[metric_name].values) + + +if __name__ == "__main__": + + test_calculate_pc_metrics(small_sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 28869ba5ff..f877f12708 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -2,7 +2,6 @@ from pathlib import Path import numpy as np - from spikeinterface.core import ( generate_ground_truth_recording, create_sorting_analyzer, @@ -15,51 +14,9 @@ compute_quality_metrics, ) - job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def get_sorting_analyzer(seed=2205): - # we need high firing rate for amplitude_cutoff - recording, sorting = generate_ground_truth_recording( - durations=[ - 120.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - generate_unit_locations_kwargs=dict( - margin_um=5.0, - minimum_z=5.0, - maximum_z=20.0, - ), - generate_templates_kwargs=dict( - unit_params=dict( - alpha=(200.0, 500.0), - ) - ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=seed, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=seed) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - sorting_analyzer = get_sorting_analyzer(seed=2205) - return sorting_analyzer - - def test_compute_quality_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple print(sorting_analyzer) @@ -118,6 +75,7 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): ) for metric_name in metrics.columns: + print(metric_name) if metric_name == "sd_ratio": # this one need recording!!! continue @@ -291,9 +249,6 @@ def test_empty_units(sorting_analyzer_simple): if __name__ == "__main__": - sorting_analyzer = get_sorting_analyzer() - print(sorting_analyzer) - - test_compute_quality_metrics(sorting_analyzer) - test_compute_quality_metrics_recordingless(sorting_analyzer) - test_empty_units(sorting_analyzer) + test_compute_quality_metrics(sorting_analyzer_simple) + test_compute_quality_metrics_recordingless(sorting_analyzer_simple) + test_empty_units(sorting_analyzer_simple) From df1564b877c26e099754f412066e5f24b0008d58 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 24 Jul 2024 13:18:00 +0100 Subject: [PATCH 2/9] try -1 instead of 2 --- src/spikeinterface/qualitymetrics/tests/conftest.py | 2 +- .../qualitymetrics/tests/test_metrics_functions.py | 2 +- src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py | 2 +- .../qualitymetrics/tests/test_quality_metric_calculator.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index b1b23fcaee..4b55a25b4a 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,7 +5,7 @@ create_sorting_analyzer, ) -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") def _small_sorting_analyzer(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 446007d10b..0df1c25586 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -44,7 +44,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") def test_unit_structure_in_output(small_sorting_analyzer): diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index a0fc97c37c..507b9a1f70 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -13,7 +13,7 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205) res1 = pd.DataFrame(res1) - res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) + res2 = compute_pc_metrics(sorting_analyzer, n_jobs=-1, progress_bar=True, seed=1205) res2 = pd.DataFrame(res2) for metric_name in res1.columns: diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index f877f12708..da1f08c536 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -14,7 +14,7 @@ compute_quality_metrics, ) -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") def test_compute_quality_metrics(sorting_analyzer_simple): From b5d896d2d0e996d0d375583f87c817e837963e23 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 24 Jul 2024 14:50:04 +0100 Subject: [PATCH 3/9] delete print statement --- .../qualitymetrics/tests/test_quality_metric_calculator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index da1f08c536..0756596654 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -75,7 +75,6 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): ) for metric_name in metrics.columns: - print(metric_name) if metric_name == "sd_ratio": # this one need recording!!! continue From 9d3aa2a7d14edb72c5729b1e4606519f1f63b3e3 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 25 Jul 2024 09:44:10 +0100 Subject: [PATCH 4/9] go back to n_jobs=1 --- .../qualitymetrics/tests/test_metrics_functions.py | 2 +- src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py | 2 +- .../qualitymetrics/tests/test_quality_metric_calculator.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 0df1c25586..446007d10b 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -44,7 +44,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype -job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") def test_unit_structure_in_output(small_sorting_analyzer): diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 507b9a1f70..a0fc97c37c 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -13,7 +13,7 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205) res1 = pd.DataFrame(res1) - res2 = compute_pc_metrics(sorting_analyzer, n_jobs=-1, progress_bar=True, seed=1205) + res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) res2 = pd.DataFrame(res2) for metric_name in res1.columns: diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 0756596654..616e6c90c1 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -14,7 +14,7 @@ compute_quality_metrics, ) -job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") def test_compute_quality_metrics(sorting_analyzer_simple): From 79f0206f73f654ead63c11f6065e10533d0309ea Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:50:52 +0100 Subject: [PATCH 5/9] Respond to Joe review --- .../qualitymetrics/tests/conftest.py | 25 ++----------- .../tests/test_metrics_functions.py | 37 ------------------- .../qualitymetrics/tests/test_pca_metrics.py | 5 --- .../tests/test_quality_metric_calculator.py | 6 --- 4 files changed, 4 insertions(+), 69 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 4b55a25b4a..01fa16c8d7 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,10 +5,11 @@ create_sorting_analyzer, ) -job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def _small_sorting_analyzer(): +@pytest.fixture(scope="module") +def small_sorting_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[2.0], num_units=10, @@ -35,12 +36,7 @@ def _small_sorting_analyzer(): @pytest.fixture(scope="module") -def small_sorting_analyzer(): - return _small_sorting_analyzer() - - -def _sorting_analyzer_simple(): - +def sorting_analyzer_simple(): # we need high firing rate for amplitude_cutoff recording, sorting = generate_ground_truth_recording( durations=[ @@ -73,16 +69,3 @@ def _sorting_analyzer_simple(): sorting_analyzer.compute("spike_amplitudes", **job_kwargs) return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - sorting_analyzer = get_sorting_analyzer(seed=2205) - return sorting_analyzer - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 446007d10b..156bab84d8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -536,40 +536,3 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) - - -if __name__ == "__main__": - - test_unit_structure_in_output(small_sorting_analyzer) - test_unit_id_order_independence(small_sorting_analyzer) - - test_synchrony_counts_no_sync() - test_synchrony_counts_one_sync() - test_synchrony_counts_one_quad_sync() - test_synchrony_counts_not_all_units() - - test_mahalanobis_metrics() - test_lda_metrics() - test_nearest_neighbors_metrics() - test_silhouette_score_metrics() - test_simplified_silhouette_score_metrics() - - test_calculate_firing_rate_num_spikes(sorting_analyzer_simple) - test_calculate_snrs(sorting_analyzer) - test_calculate_amplitude_cutoff(sorting_analyzer) - test_calculate_presence_ratio(sorting_analyzer) - test_calculate_amplitude_median(sorting_analyzer) - test_calculate_sliding_rp_violations(sorting_analyzer) - test_calculate_drift_metrics(sorting_analyzer) - test_synchrony_metrics(sorting_analyzer) - test_synchrony_metrics_unit_id_subset(sorting_analyzer) - test_synchrony_metrics_no_unit_ids(sorting_analyzer) - test_calculate_firing_range(sorting_analyzer) - test_calculate_amplitude_cv_metrics(sorting_analyzer) - test_calculate_sd_ratio(sorting_analyzer) - - sorting_analyzer_violations = _sorting_analyzer_violations() - # print(sorting_analyzer_violations) - test_calculate_isi_violations(sorting_analyzer_violations) - test_calculate_sliding_rp_violations(sorting_analyzer_violations) - test_calculate_rp_violations(sorting_analyzer_violations) diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index a0fc97c37c..6ddeb02689 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -22,8 +22,3 @@ def test_calculate_pc_metrics(small_sorting_analyzer): assert not np.all(np.isnan(res2[metric_name].values)) assert np.array_equal(res1[metric_name].values, res2[metric_name].values) - - -if __name__ == "__main__": - - test_calculate_pc_metrics(small_sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 616e6c90c1..fec5ceeb95 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -245,9 +245,3 @@ def test_empty_units(sorting_analyzer_simple): # for metric_name in metrics.columns: # # NaNs are skipped # assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) - -if __name__ == "__main__": - - test_compute_quality_metrics(sorting_analyzer_simple) - test_compute_quality_metrics_recordingless(sorting_analyzer_simple) - test_empty_units(sorting_analyzer_simple) From 5d5afd2873eb19791a7366be88ae577961f6755e Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:45:12 +0100 Subject: [PATCH 6/9] Put main stuff back --- .../tests/test_metrics_functions.py | 29 +++++++++++++++++++ .../tests/test_quality_metric_calculator.py | 10 ++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 156bab84d8..0a936edb39 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -536,3 +536,32 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) + + +if __name__ == "__main__": + + sorting_analyzer = _sorting_analyzer_simple() + print(sorting_analyzer) + + test_unit_structure_in_output(_small_sorting_analyzer()) + + # test_calculate_firing_rate_num_spikes(sorting_analyzer) + + test_calculate_snrs(sorting_analyzer) + # test_calculate_amplitude_cutoff(sorting_analyzer) + # test_calculate_presence_ratio(sorting_analyzer) + # test_calculate_amplitude_median(sorting_analyzer) + # test_calculate_sliding_rp_violations(sorting_analyzer) + # test_calculate_drift_metrics(sorting_analyzer) + # test_synchrony_metrics(sorting_analyzer) + # test_synchrony_metrics_unit_id_subset(sorting_analyzer) + # test_synchrony_metrics_no_unit_ids(sorting_analyzer) + # test_calculate_firing_range(sorting_analyzer) + # test_calculate_amplitude_cv_metrics(sorting_analyzer) + # test_calculate_sd_ratio(sorting_analyzer) + + # sorting_analyzer_violations = _sorting_analyzer_violations() + # print(sorting_analyzer_violations) + # test_calculate_isi_violations(sorting_analyzer_violations) + # test_calculate_sliding_rp_violations(sorting_analyzer_violations) + # test_calculate_rp_violations(sorting_analyzer_violations) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index fec5ceeb95..a6415c58e8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -19,7 +19,6 @@ def test_compute_quality_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple - print(sorting_analyzer) # without PCs metrics = compute_quality_metrics( @@ -245,3 +244,12 @@ def test_empty_units(sorting_analyzer_simple): # for metric_name in metrics.columns: # # NaNs are skipped # assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) + +if __name__ == "__main__": + + sorting_analyzer = get_sorting_analyzer() + print(sorting_analyzer) + + test_compute_quality_metrics(sorting_analyzer) + test_compute_quality_metrics_recordingless(sorting_analyzer) + test_empty_units(sorting_analyzer) From f3aac424c6da100aac0614b7f5382c0d797ebd17 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:46:59 +0100 Subject: [PATCH 7/9] oups --- .../qualitymetrics/tests/test_metrics_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 0a936edb39..e7fc7ce209 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -547,7 +547,7 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): # test_calculate_firing_rate_num_spikes(sorting_analyzer) - test_calculate_snrs(sorting_analyzer) + # test_calculate_snrs(sorting_analyzer) # test_calculate_amplitude_cutoff(sorting_analyzer) # test_calculate_presence_ratio(sorting_analyzer) # test_calculate_amplitude_median(sorting_analyzer) From b35873cd5a3fd7e37cf8178b9699a74cd18c08f8 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 20:08:17 +0100 Subject: [PATCH 8/9] compute pcs for sortinganalyzer again --- src/spikeinterface/qualitymetrics/tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 01fa16c8d7..676889094b 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -67,5 +67,6 @@ def sorting_analyzer_simple(): sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates") sorting_analyzer.compute("spike_amplitudes", **job_kwargs) + sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) return sorting_analyzer From 41f73ed5c68992a871531a1d832e4179c2e9e02d Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 12 Sep 2024 09:50:19 +0100 Subject: [PATCH 9/9] re-remove pcs --- src/spikeinterface/qualitymetrics/tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 676889094b..01fa16c8d7 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -67,6 +67,5 @@ def sorting_analyzer_simple(): sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates") sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) return sorting_analyzer