diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index bb2a345340..01fa16c8d7 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,8 +5,11 @@ create_sorting_analyzer, ) +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def _small_sorting_analyzer(): + +@pytest.fixture(scope="module") +def small_sorting_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[2.0], num_units=10, @@ -33,5 +36,36 @@ def _small_sorting_analyzer(): @pytest.fixture(scope="module") -def small_sorting_analyzer(): - return _small_sorting_analyzer() +def sorting_analyzer_simple(): + # we need high firing rate for amplitude_cutoff + recording, sorting = generate_ground_truth_recording( + durations=[ + 120.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params=dict( + alpha=(200.0, 500.0), + ) + ), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=1205, + ) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) + + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) + + return sorting_analyzer diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index bb222200e9..e7fc7ce209 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -135,36 +135,6 @@ def test_unit_id_order_independence(small_sorting_analyzer): assert quality_metrics_2[metric][1] == metric_1_data["#4"] -def _sorting_analyzer_simple(): - recording, sorting = generate_ground_truth_recording( - durations=[ - 50.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=2205, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() - - def _sorting_violation(): max_time = 100.0 sampling_frequency = 30000 @@ -576,6 +546,7 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): test_unit_structure_in_output(_small_sorting_analyzer()) # test_calculate_firing_rate_num_spikes(sorting_analyzer) + # test_calculate_snrs(sorting_analyzer) # test_calculate_amplitude_cutoff(sorting_analyzer) # test_calculate_presence_ratio(sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 28869ba5ff..a6415c58e8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -2,7 +2,6 @@ from pathlib import Path import numpy as np - from spikeinterface.core import ( generate_ground_truth_recording, create_sorting_analyzer, @@ -15,54 +14,11 @@ compute_quality_metrics, ) - job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def get_sorting_analyzer(seed=2205): - # we need high firing rate for amplitude_cutoff - recording, sorting = generate_ground_truth_recording( - durations=[ - 120.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - generate_unit_locations_kwargs=dict( - margin_um=5.0, - minimum_z=5.0, - maximum_z=20.0, - ), - generate_templates_kwargs=dict( - unit_params=dict( - alpha=(200.0, 500.0), - ) - ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=seed, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=seed) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - sorting_analyzer = get_sorting_analyzer(seed=2205) - return sorting_analyzer - - def test_compute_quality_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple - print(sorting_analyzer) # without PCs metrics = compute_quality_metrics(