diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index 2915cee8ec..fa1940c2ba 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -17,7 +17,6 @@
 from ..core import get_random_data_chunks, compute_sparsity
 from ..core.template_tools import get_template_extremum_channel
 
-
 _possible_pc_metric_names = [
     "isolation_distance",
     "l_ratio",
@@ -90,7 +89,7 @@ def compute_pc_metrics(
     sorting = sorting_analyzer.sorting
 
     if metric_names is None:
-        metric_names = _possible_pc_metric_names
+        metric_names = _possible_pc_metric_names.copy()
     if qm_params is None:
         qm_params = _default_params
 
@@ -110,8 +109,13 @@ def compute_pc_metrics(
     if "nn_isolation" in metric_names:
         pc_metrics["nn_unit_id"] = {}
 
+    possible_nn_metrics = ["nn_isolation", "nn_noise_overlap"]
+
+    nn_metrics = list(set(metric_names).intersection(possible_nn_metrics))
+    non_nn_metrics = list(set(metric_names).difference(possible_nn_metrics))
+
     # Compute nspikes and firing rate outside of main loop for speed
-    if any([n in metric_names for n in ["nn_isolation", "nn_noise_overlap"]]):
+    if nn_metrics:
         n_spikes_all_units = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids)
         fr_all_units = compute_firing_rates(sorting_analyzer, unit_ids=unit_ids)
     else:
@@ -120,9 +124,6 @@ def compute_pc_metrics(
 
     run_in_parallel = n_jobs > 1
 
-    if run_in_parallel:
-        parallel_functions = []
-
     # this get dense projection for selected unit_ids
     dense_projections, spike_unit_indices = pca_ext.get_some_projections(channel_ids=None, unit_ids=unit_ids)
     all_labels = sorting.unit_ids[spike_unit_indices]
@@ -146,7 +147,7 @@ def compute_pc_metrics(
         func_args = (
             pcs_flat,
             labels,
-            metric_names,
+            non_nn_metrics,
             unit_id,
             unit_ids,
             qm_params,
@@ -156,16 +157,16 @@ def compute_pc_metrics(
         )
         items.append(func_args)
 
-    if not run_in_parallel:
+    if not run_in_parallel and non_nn_metrics:
         units_loop = enumerate(unit_ids)
         if progress_bar:
-            units_loop = tqdm(units_loop, desc="calculate_pc_metrics", total=len(unit_ids))
+            units_loop = tqdm(units_loop, desc="calculate pc_metrics", total=len(unit_ids))
 
         for unit_ind, unit_id in units_loop:
            pca_metrics_unit = pca_metrics_one_unit(items[unit_ind])
            for metric_name, metric in pca_metrics_unit.items():
                pc_metrics[metric_name][unit_id] = metric
-    else:
+    elif run_in_parallel and non_nn_metrics:
         with ProcessPoolExecutor(n_jobs) as executor:
             results = executor.map(pca_metrics_one_unit, items)
             if progress_bar:
@@ -176,6 +177,37 @@ def compute_pc_metrics(
             for metric_name, metric in pca_metrics_unit.items():
                 pc_metrics[metric_name][unit_id] = metric
 
+    for metric_name in nn_metrics:
+        units_loop = enumerate(unit_ids)
+        if progress_bar:
+            units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids))
+
+        func = _nn_metric_name_to_func[metric_name]
+        metric_params = qm_params[metric_name] if metric_name in qm_params else {}
+
+        for _, unit_id in units_loop:
+            try:
+                res = func(
+                    sorting_analyzer,
+                    unit_id,
+                    seed=seed,
+                    n_spikes_all_units=n_spikes_all_units,
+                    fr_all_units=fr_all_units,
+                    **metric_params,
+                )
+            except:
+                if metric_name == "nn_isolation":
+                    res = (np.nan, np.nan)
+                elif metric_name == "nn_noise_overlap":
+                    res = np.nan
+
+            if metric_name == "nn_isolation":
+                nn_isolation, nn_unit_id = res
+                pc_metrics["nn_isolation"][unit_id] = nn_isolation
+                pc_metrics["nn_unit_id"][unit_id] = nn_unit_id
+            elif metric_name == "nn_noise_overlap":
+                pc_metrics["nn_noise_overlap"][unit_id] = res
+
     return pc_metrics
 
 
@@ -677,6 +709,14 @@ def nearest_neighbors_noise_overlap(
     templates_ext = sorting_analyzer.get_extension("templates")
     assert templates_ext is not None, "nearest_neighbors_isolation() need extension 'templates'"
 
+    try:
+        sorting_analyzer.get_extension("templates").get_data(operator="median")
+    except KeyError:
+        warnings.warn(
+            "nearest_neighbors_noise_overlap() needs the extension 'templates' calculated with the 'median' operator. "
+            "You can run sorting_analyzer.compute('templates', operators=['average', 'median']) to calculate templates based on both average and median modes."
+        )
+
     if n_spikes_all_units is None:
         n_spikes_all_units = compute_num_spikes(sorting_analyzer)
     if fr_all_units is None:
@@ -955,11 +995,13 @@ def pca_metrics_one_unit(args):
     pc_metrics = {}
     # metrics
     if "isolation_distance" in metric_names or "l_ratio" in metric_names:
+
         try:
             isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id)
         except:
             isolation_distance = np.nan
             l_ratio = np.nan
+
         if "isolation_distance" in metric_names:
             pc_metrics["isolation_distance"] = isolation_distance
         if "l_ratio" in metric_names:
@@ -973,6 +1015,7 @@ def pca_metrics_one_unit(args):
             d_prime = lda_metrics(pcs_flat, labels, unit_id)
         except:
             d_prime = np.nan
+
         pc_metrics["d_prime"] = d_prime
 
     if "nearest_neighbor" in metric_names:
@@ -986,36 +1029,6 @@ def pca_metrics_one_unit(args):
         pc_metrics["nn_hit_rate"] = nn_hit_rate
         pc_metrics["nn_miss_rate"] = nn_miss_rate
 
-    if "nn_isolation" in metric_names:
-        try:
-            nn_isolation, nn_unit_id = nearest_neighbors_isolation(
-                we,
-                unit_id,
-                seed=seed,
-                n_spikes_all_units=n_spikes_all_units,
-                fr_all_units=fr_all_units,
-                **qm_params["nn_isolation"],
-            )
-        except:
-            nn_isolation = np.nan
-            nn_unit_id = np.nan
-        pc_metrics["nn_isolation"] = nn_isolation
-        pc_metrics["nn_unit_id"] = nn_unit_id
-
-    if "nn_noise_overlap" in metric_names:
-        try:
-            nn_noise_overlap = nearest_neighbors_noise_overlap(
-                we,
-                unit_id,
-                n_spikes_all_units=n_spikes_all_units,
-                fr_all_units=fr_all_units,
-                seed=seed,
-                **qm_params["nn_noise_overlap"],
-            )
-        except:
-            nn_noise_overlap = np.nan
-        pc_metrics["nn_noise_overlap"] = nn_noise_overlap
-
     if "silhouette" in metric_names:
         silhouette_method = qm_params["silhouette"]["method"]
         if "simplified" in silhouette_method:
@@ -1032,3 +1045,9 @@ def pca_metrics_one_unit(args):
             pc_metrics["silhouette_full"] = unit_silhouette_score
 
     return pc_metrics
+
+
+_nn_metric_name_to_func = {
+    "nn_isolation": nearest_neighbors_isolation,
+    "nn_noise_overlap": nearest_neighbors_noise_overlap,
+}
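Taken together, the hunks above split the requested metrics into PC-based metrics (still dispatched to `pca_metrics_one_unit`, serially or via `ProcessPoolExecutor`) and the two nearest-neighbor metrics, which now run in their own serial loop over the whole analyzer. A minimal, self-contained sketch of the set arithmetic introduced in `compute_pc_metrics` (the `metric_names` value here is a hypothetical caller input, not part of the diff):

```python
# Sketch of the new metric split in compute_pc_metrics, assuming a caller
# passed this hypothetical metric_names list.
metric_names = ["isolation_distance", "d_prime", "nn_isolation"]
possible_nn_metrics = ["nn_isolation", "nn_noise_overlap"]

# nn metrics go to the new serial loop; everything else keeps the old
# per-unit worker path. Note that set operations do not preserve order.
nn_metrics = list(set(metric_names).intersection(possible_nn_metrics))
non_nn_metrics = list(set(metric_names).difference(possible_nn_metrics))

print(nn_metrics)      # ['nn_isolation']
print(non_nn_metrics)  # ['isolation_distance', 'd_prime'] in some order
```

This also motivates the `.copy()` change in the first hunk: without it, a caller mutating the defaulted `metric_names` would mutate the module-level `_possible_pc_metric_names` list in place.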
diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py
new file mode 100644
index 0000000000..bb2a345340
--- /dev/null
+++ b/src/spikeinterface/qualitymetrics/tests/conftest.py
@@ -0,0 +1,37 @@
+import pytest
+
+from spikeinterface.core import (
+    generate_ground_truth_recording,
+    create_sorting_analyzer,
+)
+
+
+def _small_sorting_analyzer():
+    recording, sorting = generate_ground_truth_recording(
+        durations=[2.0],
+        num_units=10,
+        seed=1205,
+    )
+
+    sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"])
+
+    sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
+
+    extensions_to_compute = {
+        "random_spikes": {"seed": 1205},
+        "noise_levels": {"seed": 1205},
+        "waveforms": {},
+        "templates": {"operators": ["average", "median"]},
+        "spike_amplitudes": {},
+        "spike_locations": {},
+        "principal_components": {},
+    }
+
+    sorting_analyzer.compute(extensions_to_compute)
+
+    return sorting_analyzer
+
+
+@pytest.fixture(scope="module")
+def small_sorting_analyzer():
+    return _small_sorting_analyzer()
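Moving the fixture into `conftest.py` makes it available to every test module in `qualitymetrics/tests` without an import: pytest discovers conftest fixtures automatically and injects them by argument name. A hypothetical consumer (the test name and assertion below are illustrative, not part of the diff):

```python
# Hypothetical test showing how the shared, module-scoped fixture is consumed.
def test_uses_shared_fixture(small_sorting_analyzer):
    # The fixture selects three ground-truth units ([2, 7, 0]) and renames
    # them to string ids, so exactly these three units remain.
    assert set(small_sorting_analyzer.unit_ids) == {"#3", "#9", "#4"}
```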
diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index aec8201f44..90b622b9ab 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -47,37 +47,6 @@
 job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
 
 
-def _small_sorting_analyzer():
-    recording, sorting = generate_ground_truth_recording(
-        durations=[2.0],
-        num_units=4,
-        seed=1205,
-    )
-
-    sorting = sorting.select_units([3, 2, 0], ["#3", "#9", "#4"])
-
-    sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
-
-    extensions_to_compute = {
-        "random_spikes": {"seed": 1205},
-        "noise_levels": {"seed": 1205},
-        "waveforms": {},
-        "templates": {},
-        "spike_amplitudes": {},
-        "spike_locations": {},
-        "principal_components": {},
-    }
-
-    sorting_analyzer.compute(extensions_to_compute)
-
-    return sorting_analyzer
-
-
-@pytest.fixture(scope="module")
-def small_sorting_analyzer():
-    return _small_sorting_analyzer()
-
-
 def test_unit_structure_in_output(small_sorting_analyzer):
 
     qm_params = {
@@ -126,7 +95,7 @@ def test_unit_id_order_independence(small_sorting_analyzer):
     """
 
     recording = small_sorting_analyzer.recording
-    sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3])
+    sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2])
 
    small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
 
@@ -161,9 +130,9 @@ def test_unit_id_order_independence(small_sorting_analyzer):
     )
 
     for metric, metric_1_data in quality_metrics_1.items():
-        assert quality_metrics_2[metric][3] == metric_1_data["#3"]
-        assert quality_metrics_2[metric][2] == metric_1_data["#9"]
-        assert quality_metrics_2[metric][0] == metric_1_data["#4"]
+        assert quality_metrics_2[metric][2] == metric_1_data["#3"]
+        assert quality_metrics_2[metric][7] == metric_1_data["#9"]
+        assert quality_metrics_2[metric][1] == metric_1_data["#4"]
 
 
 def _sorting_analyzer_simple():
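The renumbering above keeps `test_unit_id_order_independence` consistent with the new shared fixture: the three assertions encode the mapping #3→2, #9→7, #4→1 between the string ids of the first analyzer and the integer ids of the reversed one. An equivalent, loop-based restatement of those assertions (the two dicts below are made-up stand-ins for the `quality_metrics_1`/`quality_metrics_2` computed in the test):

```python
# Equivalent restatement of the three assertions at the end of
# test_unit_id_order_independence. Values here are illustrative only.
quality_metrics_1 = {"d_prime": {"#3": 1.0, "#9": 2.0, "#4": 3.0}}
quality_metrics_2 = {"d_prime": {2: 1.0, 7: 2.0, 1: 3.0}}

renaming = {"#3": 2, "#9": 7, "#4": 1}  # str ids in analyzer 1 -> int ids in analyzer 2
for metric, metric_1_data in quality_metrics_1.items():
    for str_id, int_id in renaming.items():
        assert quality_metrics_2[metric][int_id] == metric_1_data[str_id]
```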
sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates", operators=["average", "std", "median"]) - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() - - -def test_calculate_pc_metrics(sorting_analyzer_simple): +def test_calculate_pc_metrics(small_sorting_analyzer): import pandas as pd - sorting_analyzer = sorting_analyzer_simple - res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True) + sorting_analyzer = small_sorting_analyzer + res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205) res1 = pd.DataFrame(res1) - res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True) + res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) res2 = pd.DataFrame(res2) - for k in res1.columns: - mask = ~np.isnan(res1[k].values) - if np.any(mask): - assert np.array_equal(res1[k].values[mask], res2[k].values[mask]) - - -def test_nearest_neighbors_isolation(sorting_analyzer_simple): - sorting_analyzer = sorting_analyzer_simple - this_unit_id = sorting_analyzer.unit_ids[0] - nearest_neighbors_isolation(sorting_analyzer, this_unit_id) - - -def test_nearest_neighbors_noise_overlap(sorting_analyzer_simple): - sorting_analyzer = sorting_analyzer_simple - this_unit_id = sorting_analyzer.unit_ids[0] - nearest_neighbors_noise_overlap(sorting_analyzer, this_unit_id) - + for metric_name in res1.columns: + if metric_name != "nn_unit_id": + assert not np.all(np.isnan(res1[metric_name].values)) + assert not np.all(np.isnan(res2[metric_name].values)) -if __name__ == "__main__": - sorting_analyzer = _sorting_analyzer_simple() - test_calculate_pc_metrics(sorting_analyzer) - test_nearest_neighbors_isolation(sorting_analyzer) - test_nearest_neighbors_noise_overlap(sorting_analyzer) + assert np.array_equal(res1[metric_name].values, res2[metric_name].values)