From 0a49ebcec7435f25574d542dbcc198b8e5b4c12f Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 4 Jul 2024 10:25:53 +0200 Subject: [PATCH 1/5] Fix nn pca_metric computation and update tests --- .../qualitymetrics/pca_metrics.py | 137 ++++++++++-------- .../tests/test_metrics_functions.py | 14 +- .../qualitymetrics/tests/test_pca_metrics.py | 14 +- 3 files changed, 90 insertions(+), 75 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 2915cee8ec..f270627b43 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -90,7 +90,7 @@ def compute_pc_metrics( sorting = sorting_analyzer.sorting if metric_names is None: - metric_names = _possible_pc_metric_names + metric_names = _possible_pc_metric_names.copy() if qm_params is None: qm_params = _default_params @@ -110,8 +110,15 @@ def compute_pc_metrics( if "nn_isolation" in metric_names: pc_metrics["nn_unit_id"] = {} + possible_nn_metrics = ["nn_isolation", "nn_noise_overlap"] + nn_metrics = [] + for possible_nn_metric in possible_nn_metrics: + if possible_nn_metric in metric_names: + metric_names.remove(possible_nn_metric) + nn_metrics.append(possible_nn_metric) + # Compute nspikes and firing rate outside of main loop for speed - if any([n in metric_names for n in ["nn_isolation", "nn_noise_overlap"]]): + if nn_metrics: n_spikes_all_units = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids) fr_all_units = compute_firing_rates(sorting_analyzer, unit_ids=unit_ids) else: @@ -120,9 +127,6 @@ def compute_pc_metrics( run_in_parallel = n_jobs > 1 - if run_in_parallel: - parallel_functions = [] - # this get dense projection for selected unit_ids dense_projections, spike_unit_indices = pca_ext.get_some_projections(channel_ids=None, unit_ids=unit_ids) all_labels = sorting.unit_ids[spike_unit_indices] @@ -156,16 +160,16 @@ def compute_pc_metrics( ) items.append(func_args) - if not run_in_parallel: + if not run_in_parallel and metric_names: units_loop = enumerate(unit_ids) if progress_bar: - units_loop = tqdm(units_loop, desc="calculate_pc_metrics", total=len(unit_ids)) + units_loop = tqdm(units_loop, desc="calculate non nn pc_metrics", total=len(unit_ids)) for unit_ind, unit_id in units_loop: pca_metrics_unit = pca_metrics_one_unit(items[unit_ind]) for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric - else: + elif run_in_parallel and metric_names: with ProcessPoolExecutor(n_jobs) as executor: results = executor.map(pca_metrics_one_unit, items) if progress_bar: @@ -176,6 +180,44 @@ def compute_pc_metrics( for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric + if "nn_isolation" in nn_metrics: + units_loop = enumerate(unit_ids) + if progress_bar: + units_loop = tqdm(units_loop, desc="calculate nn_isolation metric", total=len(unit_ids)) + + for unit_ind, unit_id in units_loop: + nn_isolation, nn_unit_id = nearest_neighbors_isolation( + sorting_analyzer, + unit_id, + seed=seed, + n_spikes_all_units=n_spikes_all_units, + fr_all_units=fr_all_units, + **qm_params["nn_isolation"], + ) + + pc_metrics["nn_isolation"][unit_id] = nn_isolation + pc_metrics["nn_unit_id"][unit_id] = nn_unit_id + + if "nn_noise_overlap" in nn_metrics: + units_loop = enumerate(unit_ids) + if progress_bar: + units_loop = tqdm( + units_loop, + desc="calculate nn_noise_overlap metric", + total=len(unit_ids), + ) + + for unit_ind, unit_id in units_loop: + nn_noise_overlap = nearest_neighbors_noise_overlap( + sorting_analyzer, + unit_id, + n_spikes_all_units=n_spikes_all_units, + fr_all_units=fr_all_units, + seed=2205, + **qm_params["nn_noise_overlap"], + ) + pc_metrics["nn_noise_overlap"][unit_id] = nn_noise_overlap + return pc_metrics @@ -677,6 +719,14 @@ def nearest_neighbors_noise_overlap( templates_ext = sorting_analyzer.get_extension("templates") assert templates_ext is not None, "nearest_neighbors_isolation() need extension 'templates'" + try: + sorting_analyzer.get_extension("templates").get_data(operator="median") + except KeyError: + warnings.warn( + "nearest_neighbors_isolation() need extension 'templates' calculated with the 'median' operator." + "You can run sorting_analyzer.compute('templates', operators=['average', 'median']) to calculate templates based on both average and median modes." + ) + if n_spikes_all_units is None: n_spikes_all_units = compute_num_spikes(sorting_analyzer) if fr_all_units is None: @@ -955,11 +1005,9 @@ def pca_metrics_one_unit(args): pc_metrics = {} # metrics if "isolation_distance" in metric_names or "l_ratio" in metric_names: - try: - isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id) - except: - isolation_distance = np.nan - l_ratio = np.nan + + isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id) + if "isolation_distance" in metric_names: pc_metrics["isolation_distance"] = isolation_distance if "l_ratio" in metric_names: @@ -969,66 +1017,31 @@ def pca_metrics_one_unit(args): if len(unit_ids) == 1: d_prime = np.nan else: - try: - d_prime = lda_metrics(pcs_flat, labels, unit_id) - except: - d_prime = np.nan + + d_prime = lda_metrics(pcs_flat, labels, unit_id) + pc_metrics["d_prime"] = d_prime if "nearest_neighbor" in metric_names: - try: - nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics( - pcs_flat, labels, unit_id, **qm_params["nearest_neighbor"] - ) - except: - nn_hit_rate = np.nan - nn_miss_rate = np.nan - pc_metrics["nn_hit_rate"] = nn_hit_rate - pc_metrics["nn_miss_rate"] = nn_miss_rate - if "nn_isolation" in metric_names: - try: - nn_isolation, nn_unit_id = nearest_neighbors_isolation( - we, - unit_id, - seed=seed, - n_spikes_all_units=n_spikes_all_units, - fr_all_units=fr_all_units, - **qm_params["nn_isolation"], - ) - except: - nn_isolation = np.nan - nn_unit_id = np.nan - pc_metrics["nn_isolation"] = nn_isolation - pc_metrics["nn_unit_id"] = nn_unit_id + nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics( + pcs_flat, labels, unit_id, **qm_params["nearest_neighbor"] + ) - if "nn_noise_overlap" in metric_names: - try: - nn_noise_overlap = nearest_neighbors_noise_overlap( - we, - unit_id, - n_spikes_all_units=n_spikes_all_units, - fr_all_units=fr_all_units, - seed=seed, - **qm_params["nn_noise_overlap"], - ) - except: - nn_noise_overlap = np.nan - pc_metrics["nn_noise_overlap"] = nn_noise_overlap + pc_metrics["nn_hit_rate"] = nn_hit_rate + pc_metrics["nn_miss_rate"] = nn_miss_rate if "silhouette" in metric_names: silhouette_method = qm_params["silhouette"]["method"] if "simplified" in silhouette_method: - try: - unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id) - except: - unit_silhouette_score = np.nan + + unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id) + pc_metrics["silhouette"] = unit_silhouette_score if "full" in silhouette_method: - try: - unit_silhouette_score = silhouette_score(pcs_flat, labels, unit_id) - except: - unit_silhouette_score = np.nan + + unit_silhouette_score = silhouette_score(pcs_flat, labels, unit_id) + pc_metrics["silhouette_full"] = unit_silhouette_score return pc_metrics diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index aec8201f44..1f9775589d 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -50,11 +50,11 @@ def _small_sorting_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[2.0], - num_units=4, + num_units=10, seed=1205, ) - sorting = sorting.select_units([3, 2, 0], ["#3", "#9", "#4"]) + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") @@ -62,7 +62,7 @@ def _small_sorting_analyzer(): "random_spikes": {"seed": 1205}, "noise_levels": {"seed": 1205}, "waveforms": {}, - "templates": {}, + "templates": {"operators": ["average", "median"]}, "spike_amplitudes": {}, "spike_locations": {}, "principal_components": {}, @@ -126,7 +126,7 @@ def test_unit_id_order_independence(small_sorting_analyzer): """ recording = small_sorting_analyzer.recording - sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3]) + sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2]) small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") @@ -161,9 +161,9 @@ def test_unit_id_order_independence(small_sorting_analyzer): ) for metric, metric_1_data in quality_metrics_1.items(): - assert quality_metrics_2[metric][3] == metric_1_data["#3"] - assert quality_metrics_2[metric][2] == metric_1_data["#9"] - assert quality_metrics_2[metric][0] == metric_1_data["#4"] + assert quality_metrics_2[metric][2] == metric_1_data["#3"] + assert quality_metrics_2[metric][7] == metric_1_data["#9"] + assert quality_metrics_2[metric][1] == metric_1_data["#4"] def _sorting_analyzer_simple(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 4e5a4858bb..aa6c54cfa0 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -13,7 +13,6 @@ nearest_neighbors_noise_overlap, ) - job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") @@ -47,18 +46,21 @@ def sorting_analyzer_simple(): return _sorting_analyzer_simple() -def test_calculate_pc_metrics(sorting_analyzer_simple): +def test_calculate_pc_metrics(small_sorting_analyzer): import pandas as pd - sorting_analyzer = sorting_analyzer_simple - res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True) + sorting_analyzer = small_sorting_analyzer + res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205) res1 = pd.DataFrame(res1) - res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True) + res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) res2 = pd.DataFrame(res2) for k in res1.columns: - mask = ~np.isnan(res1[k].values) + if k == "nn_unit_id": + mask = [True, True, True] + else: + mask = ~np.isnan(res1[k].values) if np.any(mask): assert np.array_equal(res1[k].values[mask], res2[k].values[mask]) From 072c36c18df0848e969e09b67ac5e4b226757f73 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 9 Jul 2024 10:34:57 +0200 Subject: [PATCH 2/5] Respond to review and fix fixture in testing --- .../qualitymetrics/pca_metrics.py | 87 ++++++++++--------- .../qualitymetrics/tests/conftest.py | 37 ++++++++ .../tests/test_metrics_functions.py | 31 ------- .../qualitymetrics/tests/test_pca_metrics.py | 70 ++------------- 4 files changed, 87 insertions(+), 138 deletions(-) create mode 100644 src/spikeinterface/qualitymetrics/tests/conftest.py diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index f270627b43..f34adf5ffb 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -17,7 +17,6 @@ from ..core import get_random_data_chunks, compute_sparsity from ..core.template_tools import get_template_extremum_channel - _possible_pc_metric_names = [ "isolation_distance", "l_ratio", @@ -163,7 +162,7 @@ def compute_pc_metrics( if not run_in_parallel and metric_names: units_loop = enumerate(unit_ids) if progress_bar: - units_loop = tqdm(units_loop, desc="calculate non nn pc_metrics", total=len(unit_ids)) + units_loop = tqdm(units_loop, desc="calculate pc_metrics", total=len(unit_ids)) for unit_ind, unit_id in units_loop: pca_metrics_unit = pca_metrics_one_unit(items[unit_ind]) @@ -180,43 +179,31 @@ def compute_pc_metrics( for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric - if "nn_isolation" in nn_metrics: + for metric_name in nn_metrics: units_loop = enumerate(unit_ids) if progress_bar: - units_loop = tqdm(units_loop, desc="calculate nn_isolation metric", total=len(unit_ids)) + units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids)) + + func = _nn_metric_name_to_func[metric_name] + metric_params = qm_params[metric_name] if metric_name in qm_params else {} for unit_ind, unit_id in units_loop: - nn_isolation, nn_unit_id = nearest_neighbors_isolation( + + res = func( sorting_analyzer, unit_id, seed=seed, n_spikes_all_units=n_spikes_all_units, fr_all_units=fr_all_units, - **qm_params["nn_isolation"], + **metric_params, ) - pc_metrics["nn_isolation"][unit_id] = nn_isolation - pc_metrics["nn_unit_id"][unit_id] = nn_unit_id - - if "nn_noise_overlap" in nn_metrics: - units_loop = enumerate(unit_ids) - if progress_bar: - units_loop = tqdm( - units_loop, - desc="calculate nn_noise_overlap metric", - total=len(unit_ids), - ) - - for unit_ind, unit_id in units_loop: - nn_noise_overlap = nearest_neighbors_noise_overlap( - sorting_analyzer, - unit_id, - n_spikes_all_units=n_spikes_all_units, - fr_all_units=fr_all_units, - seed=2205, - **qm_params["nn_noise_overlap"], - ) - pc_metrics["nn_noise_overlap"][unit_id] = nn_noise_overlap + if metric_name == "nn_isolation": + nn_isolation, nn_unit_id = res + pc_metrics["nn_isolation"][unit_id] = nn_isolation + pc_metrics["nn_unit_id"][unit_id] = nn_unit_id + elif metric_name == "nn_noise_overlap": + pc_metrics["nn_noise_overlap"][unit_id] = res return pc_metrics @@ -1006,7 +993,11 @@ def pca_metrics_one_unit(args): # metrics if "isolation_distance" in metric_names or "l_ratio" in metric_names: - isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id) + try: + isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id) + except: + isolation_distance = np.nan + l_ratio = np.nan if "isolation_distance" in metric_names: pc_metrics["isolation_distance"] = isolation_distance @@ -1017,31 +1008,43 @@ def pca_metrics_one_unit(args): if len(unit_ids) == 1: d_prime = np.nan else: - - d_prime = lda_metrics(pcs_flat, labels, unit_id) + try: + d_prime = lda_metrics(pcs_flat, labels, unit_id) + except: + d_prime = np.nan pc_metrics["d_prime"] = d_prime if "nearest_neighbor" in metric_names: - - nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics( - pcs_flat, labels, unit_id, **qm_params["nearest_neighbor"] - ) - + try: + nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics( + pcs_flat, labels, unit_id, **qm_params["nearest_neighbor"] + ) + except: + nn_hit_rate = np.nan + nn_miss_rate = np.nan pc_metrics["nn_hit_rate"] = nn_hit_rate pc_metrics["nn_miss_rate"] = nn_miss_rate if "silhouette" in metric_names: silhouette_method = qm_params["silhouette"]["method"] if "simplified" in silhouette_method: - - unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id) - + try: + unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id) + except: + unit_silhouette_score = np.nan pc_metrics["silhouette"] = unit_silhouette_score if "full" in silhouette_method: - - unit_silhouette_score = silhouette_score(pcs_flat, labels, unit_id) - + try: + unit_silhouette_score = silhouette_score(pcs_flat, labels, unit_id) + except: + unit_silhouette_score = np.nan pc_metrics["silhouette_full"] = unit_silhouette_score return pc_metrics + + +_nn_metric_name_to_func = { + "nn_isolation": nearest_neighbors_isolation, + "nn_noise_overlap": nearest_neighbors_noise_overlap, +} diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py new file mode 100644 index 0000000000..bb2a345340 --- /dev/null +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -0,0 +1,37 @@ +import pytest + +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, +) + + +def _small_sorting_analyzer(): + recording, sorting = generate_ground_truth_recording( + durations=[2.0], + num_units=10, + seed=1205, + ) + + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) + + sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + sorting_analyzer.compute(extensions_to_compute) + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def small_sorting_analyzer(): + return _small_sorting_analyzer() diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 1f9775589d..90b622b9ab 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -47,37 +47,6 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def _small_sorting_analyzer(): - recording, sorting = generate_ground_truth_recording( - durations=[2.0], - num_units=10, - seed=1205, - ) - - sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) - - sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") - - extensions_to_compute = { - "random_spikes": {"seed": 1205}, - "noise_levels": {"seed": 1205}, - "waveforms": {}, - "templates": {"operators": ["average", "median"]}, - "spike_amplitudes": {}, - "spike_locations": {}, - "principal_components": {}, - } - - sorting_analyzer.compute(extensions_to_compute) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def small_sorting_analyzer(): - return _small_sorting_analyzer() - - def test_unit_structure_in_output(small_sorting_analyzer): qm_params = { diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index aa6c54cfa0..6ddeb02689 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -1,50 +1,10 @@ import pytest -from pathlib import Path import numpy as np -from spikeinterface.core import ( - generate_ground_truth_recording, - create_sorting_analyzer, -) - from spikeinterface.qualitymetrics import ( compute_pc_metrics, - nearest_neighbors_isolation, - nearest_neighbors_noise_overlap, ) -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - - -def _sorting_analyzer_simple(): - recording, sorting = generate_ground_truth_recording( - durations=[ - 50.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=2205, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates", operators=["average", "std", "median"]) - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() - def test_calculate_pc_metrics(small_sorting_analyzer): import pandas as pd @@ -56,29 +16,9 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) res2 = pd.DataFrame(res2) - for k in res1.columns: - if k == "nn_unit_id": - mask = [True, True, True] - else: - mask = ~np.isnan(res1[k].values) - if np.any(mask): - assert np.array_equal(res1[k].values[mask], res2[k].values[mask]) - - -def test_nearest_neighbors_isolation(sorting_analyzer_simple): - sorting_analyzer = sorting_analyzer_simple - this_unit_id = sorting_analyzer.unit_ids[0] - nearest_neighbors_isolation(sorting_analyzer, this_unit_id) - - -def test_nearest_neighbors_noise_overlap(sorting_analyzer_simple): - sorting_analyzer = sorting_analyzer_simple - this_unit_id = sorting_analyzer.unit_ids[0] - nearest_neighbors_noise_overlap(sorting_analyzer, this_unit_id) - + for metric_name in res1.columns: + if metric_name != "nn_unit_id": + assert not np.all(np.isnan(res1[metric_name].values)) + assert not np.all(np.isnan(res2[metric_name].values)) -if __name__ == "__main__": - sorting_analyzer = _sorting_analyzer_simple() - test_calculate_pc_metrics(sorting_analyzer) - test_nearest_neighbors_isolation(sorting_analyzer) - test_nearest_neighbors_noise_overlap(sorting_analyzer) + assert np.array_equal(res1[metric_name].values, res2[metric_name].values) From 1e545233c68da2e1139dc624bb097977fc2e87f5 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 9 Jul 2024 12:36:45 +0200 Subject: [PATCH 3/5] Add try except to nn metrics --- .../qualitymetrics/pca_metrics.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index f34adf5ffb..8773950449 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -188,15 +188,20 @@ def compute_pc_metrics( metric_params = qm_params[metric_name] if metric_name in qm_params else {} for unit_ind, unit_id in units_loop: - - res = func( - sorting_analyzer, - unit_id, - seed=seed, - n_spikes_all_units=n_spikes_all_units, - fr_all_units=fr_all_units, - **metric_params, - ) + try: + res = func( + sorting_analyzer, + unit_id, + seed=seed, + n_spikes_all_units=n_spikes_all_units, + fr_all_units=fr_all_units, + **metric_params, + ) + except: + if metric_name == "nn_isolation": + res = (np.nan, np.nan) + elif metric_name == "nn_noise_overlap": + res = np.nan if metric_name == "nn_isolation": nn_isolation, nn_unit_id = res From e36eaccb33842945d9bb11037b36fc385a2e90d8 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Fri, 12 Jul 2024 10:00:24 +0100 Subject: [PATCH 4/5] Split metric_names in nn_metrics and non_nn_metrics --- src/spikeinterface/qualitymetrics/pca_metrics.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 8773950449..0686c16d57 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -110,11 +110,9 @@ def compute_pc_metrics( pc_metrics["nn_unit_id"] = {} possible_nn_metrics = ["nn_isolation", "nn_noise_overlap"] - nn_metrics = [] - for possible_nn_metric in possible_nn_metrics: - if possible_nn_metric in metric_names: - metric_names.remove(possible_nn_metric) - nn_metrics.append(possible_nn_metric) + + nn_metrics = list(set(metric_names).intersection(possible_nn_metrics)) + non_nn_metrics = list(set(metric_names).difference(possible_nn_metrics)) # Compute nspikes and firing rate outside of main loop for speed if nn_metrics: @@ -149,7 +147,7 @@ def compute_pc_metrics( func_args = ( pcs_flat, labels, - metric_names, + non_nn_metrics, unit_id, unit_ids, qm_params, @@ -159,7 +157,7 @@ def compute_pc_metrics( ) items.append(func_args) - if not run_in_parallel and metric_names: + if not run_in_parallel and non_nn_metrics: units_loop = enumerate(unit_ids) if progress_bar: units_loop = tqdm(units_loop, desc="calculate pc_metrics", total=len(unit_ids)) @@ -168,7 +166,7 @@ def compute_pc_metrics( pca_metrics_unit = pca_metrics_one_unit(items[unit_ind]) for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric - elif run_in_parallel and metric_names: + elif run_in_parallel and non_nn_metrics: with ProcessPoolExecutor(n_jobs) as executor: results = executor.map(pca_metrics_one_unit, items) if progress_bar: From 135b0c36aaf09a4fa900f56fc5af74e518d58b99 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:32:26 +0100 Subject: [PATCH 5/5] remove unit_ind --- src/spikeinterface/qualitymetrics/pca_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 0686c16d57..fa1940c2ba 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -185,7 +185,7 @@ def compute_pc_metrics( func = _nn_metric_name_to_func[metric_name] metric_params = qm_params[metric_name] if metric_name in qm_params else {} - for unit_ind, unit_id in units_loop: + for _, unit_id in units_loop: try: res = func( sorting_analyzer,