From a9a511af00aa7b6dbda0739f10dfa02dd7a11cc1 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 7 Aug 2024 16:39:54 +0100 Subject: [PATCH 01/17] Do not delete quality metrics on recompute --- .../quality_metric_calculator.py | 27 +++++++-- .../tests/test_metrics_functions.py | 59 +++++++++++++++++++ 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 0c7cf25237..25b8cc7c05 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -2,10 +2,11 @@ from __future__ import annotations - +import weakref import warnings from copy import deepcopy +import pandas as pd import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs @@ -49,6 +50,18 @@ class ComputeQualityMetrics(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = True + def __init__(self, sorting_analyzer): + + self._sorting_analyzer = weakref.ref(sorting_analyzer) + + qm_extension = sorting_analyzer.get_extension("quality_metrics") + if qm_extension: + self.params = qm_extension.params + self.data = {"metrics": qm_extension.get_data()} + else: + self.params = {} + self.data = {"metrics": pd.DataFrame(index=sorting_analyzer.sorting.unit_ids)} + def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False): if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) @@ -71,8 +84,14 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No if "peak_sign" in qm_params_[k] and peak_sign is not None: qm_params_[k]["peak_sign"] = peak_sign + try: + existing_metric_names = self.sorting_analyzer.get_extension("quality_metrics").params.get("metric_names") + metric_names_for_params = np.concatenate((existing_metric_names, metric_names)) + except: + metric_names_for_params = metric_names + params = dict( - metric_names=[str(name) for name in np.unique(metric_names)], + metric_names=[str(name) for name in np.unique(metric_names_for_params)], peak_sign=peak_sign, seed=seed, qm_params=qm_params_, @@ -89,8 +108,6 @@ def _select_extension_data(self, unit_ids): def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): - import pandas as pd - old_metrics = self.data["metrics"] all_unit_ids = new_sorting_analyzer.unit_ids @@ -134,7 +151,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job import pandas as pd - metrics = pd.DataFrame(index=unit_ids) + metrics = self.data["metrics"] # simple metrics not based on PCs for metric_name in metric_names: diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index bb222200e9..e34c15c936 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -1,6 +1,7 @@ import pytest from pathlib import Path import numpy as np +from copy import deepcopy from spikeinterface.core import ( NumpySorting, synthetize_spike_train_bad_isi, @@ -47,6 +48,64 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +def test_compute_new_quality_metrics(small_sorting_analyzer): + """ + Computes quality metrics then computes a subset of 
quality metrics, and checks + that the old quality metrics are not deleted. + """ + + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "firing_range": {"bin_size_s": 1}, + } + + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}} + ) + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) + + quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + + # Check old metrics are not deleted and the new one is added to the data and metadata + assert list(quality_metric_extension.get_data().keys()) == [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + assert list(quality_metric_extension.params.get("metric_names")) == [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + + # check that, when parameters are changed, the data and metadata are updated + old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": ["snr"], "qm_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + ) + new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + new_snr_data = new_quality_metric_extension.get_data()["snr"].values + + assert np.all(old_snr_data != new_snr_data) + assert new_quality_metric_extension.params["qm_params"]["snr"]["peak_mode"] == "peak_to_peak" + + # check that all quality metrics are deleted when parents are recomputed, even after + # recomputation + extensions_to_compute = { + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + small_sorting_analyzer.compute(extensions_to_compute) + + assert small_sorting_analyzer.get_extension("quality_metrics") is None + + def test_unit_structure_in_output(small_sorting_analyzer): qm_params = { From 63713db9b026c6ab55a2e452396a860b78c9934a Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 7 Aug 2024 17:16:51 +0100 Subject: [PATCH 02/17] Change where pandas is imported --- .../qualitymetrics/quality_metric_calculator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 25b8cc7c05..b34407027b 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -6,7 +6,6 @@ import warnings from copy import deepcopy -import pandas as pd import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs @@ -60,7 +59,7 @@ def __init__(self, sorting_analyzer): self.data = {"metrics": qm_extension.get_data()} else: self.params = {} - self.data = {"metrics": pd.DataFrame(index=sorting_analyzer.sorting.unit_ids)} + self.data = {"metrics": None} def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False): if metric_names is None: @@ -152,6 +151,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job import pandas as pd metrics = self.data["metrics"] + if metrics is None: + metrics = pd.DataFrame(index=sorting_analyzer.sorting.unit_ids) # simple metrics not based on PCs for metric_name in metric_names: From 01abc84370c2bfc8acd81523ea18c2e6e15b03de Mon Sep 17 00:00:00 2001 
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 8 Aug 2024 09:14:47 +0100 Subject: [PATCH 03/17] Replace try/except with some ifs for metric names in params --- .../qualitymetrics/quality_metric_calculator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index b34407027b..ced1eedbd2 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -83,11 +83,12 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No if "peak_sign" in qm_params_[k] and peak_sign is not None: qm_params_[k]["peak_sign"] = peak_sign - try: - existing_metric_names = self.sorting_analyzer.get_extension("quality_metrics").params.get("metric_names") - metric_names_for_params = np.concatenate((existing_metric_names, metric_names)) - except: - metric_names_for_params = metric_names + metric_names_for_params = metric_names + qm_extension = self.sorting_analyzer.get_extension("quality_metrics") + if qm_extension: + existing_metric_names = qm_extension.params.get("metric_names") + if existing_metric_names is not None: + metric_names_for_params.extend(existing_metric_names) params = dict( metric_names=[str(name) for name in np.unique(metric_names_for_params)], From e6b394115c52c370af8c8411186f97d8e99f24cc Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:15:09 +0100 Subject: [PATCH 04/17] Fix problem with loading sorting_analyzer with qms --- .../qualitymetrics/quality_metric_calculator.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index ced1eedbd2..708498e3fa 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -53,10 +53,11 @@ def __init__(self, sorting_analyzer): self._sorting_analyzer = weakref.ref(sorting_analyzer) - qm_extension = sorting_analyzer.get_extension("quality_metrics") - if qm_extension: - self.params = qm_extension.params - self.data = {"metrics": qm_extension.get_data()} + qm_class = sorting_analyzer.extensions.get("quality_metrics") + + if qm_class: + self.params = qm_class.params + self.data = {"metrics": qm_class.get_data()} else: self.params = {} self.data = {"metrics": None} From 762e8faec7fca93379f9be1bf9be398eb65d4774 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 8 Aug 2024 15:49:33 +0100 Subject: [PATCH 05/17] update if statement --- src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 708498e3fa..c3c95a2f54 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -86,7 +86,7 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No metric_names_for_params = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") - if qm_extension: + if qm_extension is not None: 
existing_metric_names = qm_extension.params.get("metric_names") if existing_metric_names is not None: metric_names_for_params.extend(existing_metric_names) From 6dcf0b468f5b9daf790e9f9bb6d9b5c3a4e17f37 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Fri, 6 Sep 2024 16:15:06 +0100 Subject: [PATCH 06/17] Move bulk of calc from init to set_params and run --- .../qualitymetrics/misc_metrics.py | 13 +++++ .../quality_metric_calculator.py | 52 +++++++++++-------- 2 files changed, 42 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 2de31ad750..8dfd41cf88 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -69,6 +69,9 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): return num_spikes +_default_params["num_spikes"] = {} + + def compute_firing_rates(sorting_analyzer, unit_ids=None): """ Compute the firing rate across segments. @@ -98,6 +101,9 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None): return firing_rates +_default_params["firing_rate"] = {} + + def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -1550,3 +1556,10 @@ def compute_sd_ratio( sd_ratio[unit_id] = unit_std / std_noise return sd_ratio + + +_default_params["sd_ratio"] = dict( + censored_period_ms=4.0, + correct_for_drift=True, + correct_for_template_itself=True, +) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index c3c95a2f54..b85d3bcfd3 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -2,7 +2,6 @@ from __future__ import annotations -import weakref import warnings from copy import deepcopy @@ -30,8 +29,10 @@ class ComputeQualityMetrics(AnalyzerExtension): qm_params : dict or None Dictionary with parameters for quality metrics calculation. Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` - skip_pc_metrics : bool + skip_pc_metrics : bool, default: False If True, PC metrics computation is skipped. 
+ delete_existing_metrics : bool, default: False + If True, deletes any quality_metrics attached to the `sorting_analyzer` Returns ------- @@ -49,20 +50,16 @@ class ComputeQualityMetrics(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = True - def __init__(self, sorting_analyzer): - - self._sorting_analyzer = weakref.ref(sorting_analyzer) - - qm_class = sorting_analyzer.extensions.get("quality_metrics") - - if qm_class: - self.params = qm_class.params - self.data = {"metrics": qm_class.get_data()} - else: - self.params = {} - self.data = {"metrics": None} + def _set_params( + self, + metric_names=None, + qm_params=None, + peak_sign=None, + seed=None, + skip_pc_metrics=False, + delete_existing_metrics=False, + ): - def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False): if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) # if PC is available, PC metrics are automatically added to the list @@ -84,15 +81,17 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No if "peak_sign" in qm_params_[k] and peak_sign is not None: qm_params_[k]["peak_sign"] = peak_sign - metric_names_for_params = metric_names + all_metric_names = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") - if qm_extension is not None: - existing_metric_names = qm_extension.params.get("metric_names") - if existing_metric_names is not None: - metric_names_for_params.extend(existing_metric_names) + if delete_existing_metrics is False and qm_extension is not None: + existing_params = qm_extension.params + for metric_name in existing_params["metric_names"]: + if metric_name not in metric_names: + all_metric_names.append(metric_name) + qm_params_[metric_name] = existing_params["qm_params"][metric_name] params = dict( - metric_names=[str(name) for name in np.unique(metric_names_for_params)], + metric_names=[str(name) for name in np.unique(all_metric_names)], peak_sign=peak_sign, seed=seed, qm_params=qm_params_, @@ -152,7 +151,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job import pandas as pd - metrics = self.data["metrics"] + metrics = self.data.get("metrics") if metrics is None: metrics = pd.DataFrame(index=sorting_analyzer.sorting.unit_ids) @@ -204,11 +203,18 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job return metrics - def _run(self, verbose=False, **job_kwargs): + def _run(self, verbose=False, delete_existing_metrics=False, **job_kwargs): self.data["metrics"] = self._compute_metrics( sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, **job_kwargs ) + qm_extension = self.sorting_analyzer.get_extension("quality_metrics") + if delete_existing_metrics is False and qm_extension is not None: + existing_metrics = qm_extension.get_data() + for metric_name, metric_data in existing_metrics.items(): + if metric_name not in self.data["metrics"]: + self.data["metrics"][metric_name] = metric_data + def _get_data(self): return self.data["metrics"] From 4ad6a884c3cba86a4eb0f2f1979937a25ed4f0bf Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Sat, 7 Sep 2024 12:38:04 +0100 Subject: [PATCH 07/17] Add template_metrics --- .../postprocessing/template_metrics.py | 36 +++++++++++++++++-- .../postprocessing/tests/conftest.py | 33 +++++++++++++++++ .../tests/test_template_metrics.py | 30 ++++++++++++++++ .../quality_metric_calculator.py | 4 +-- 4 files 
changed, 98 insertions(+), 5 deletions(-) create mode 100644 src/spikeinterface/postprocessing/tests/conftest.py diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index e16bd9ad27..57d8fd5839 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -64,6 +64,8 @@ class ComputeTemplateMetrics(AnalyzerExtension): For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics + delete_existing_metrics : bool, default: False + If True, deletes any quality_metrics attached to the `sorting_analyzer` metrics_kwargs : dict Additional arguments to pass to the metric functions. Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 @@ -111,8 +113,10 @@ def _set_params( sparsity=None, metrics_kwargs=None, include_multi_channel_metrics=False, + delete_existing_metrics=False, **other_kwargs, ): + import pandas as pd # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() if include_multi_channel_metrics or ( @@ -140,9 +144,30 @@ def _set_params( else: metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) + print(metrics_kwargs_) + + all_metric_names = metric_names + tm_extension = self.sorting_analyzer.get_extension("template_metrics") + if delete_existing_metrics is False and tm_extension is not None: + existing_metric_names = tm_extension.params["metric_names"] + existing_params = tm_extension.params["metrics_kwargs"] + + # checks that existing metrics were calculated using the same params + if existing_params != metrics_kwargs_: + warnings.warn( + "The parameters used to calculate the previous template metrics are different than those used now. Deleting previous template metrics..." 
+ ) + self.sorting_analyzer.get_extension("template_metrics").data["metrics"] = pd.DataFrame( + index=self.sorting_analyzer.unit_ids + ) + existing_metric_names = [] + + for metric_name in existing_metric_names: + if metric_name not in metric_names: + all_metric_names.append(metric_name) params = dict( - metric_names=[str(name) for name in np.unique(metric_names)], + metric_names=[str(name) for name in np.unique(all_metric_names)], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), @@ -283,11 +308,18 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job template_metrics.at[index, metric_name] = value return template_metrics - def _run(self, verbose=False): + def _run(self, delete_existing_metrics=False, verbose=False): self.data["metrics"] = self._compute_metrics( sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose ) + tm_extension = self.sorting_analyzer.get_extension("template_metrics") + if delete_existing_metrics is False and tm_extension is not None: + existing_metrics = tm_extension.get_data() + for metric_name, metric_data in existing_metrics.items(): + if metric_name not in self.data["metrics"]: + self.data["metrics"][metric_name] = metric_data + def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/postprocessing/tests/conftest.py b/src/spikeinterface/postprocessing/tests/conftest.py new file mode 100644 index 0000000000..51ac8aa250 --- /dev/null +++ b/src/spikeinterface/postprocessing/tests/conftest.py @@ -0,0 +1,33 @@ +import pytest + +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, +) + + +def _small_sorting_analyzer(): + recording, sorting = generate_ground_truth_recording( + durations=[2.0], + num_units=10, + seed=1205, + ) + + sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + } + + sorting_analyzer.compute(extensions_to_compute) + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def small_sorting_analyzer(): + return _small_sorting_analyzer() diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 694aa083cc..f444e12c36 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -3,6 +3,36 @@ import pytest +def test_compute_new_template_metrics(small_sorting_analyzer): + """ + Computes template metrics then computes a subset of quality metrics, and checks + that the old quality metrics are not deleted. + + Then computes template metrics with new parameters and checks that old metrics + are deleted. 
+ """ + + small_sorting_analyzer.compute("template_metrics") + small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) + + template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") + + # Check old metrics are not deleted and the new one is added to the data and metadata + assert "exp_decay" in list(template_metric_extension.get_data().keys()) + assert "half_width" in list(template_metric_extension.get_data().keys()) + + # check that, when parameters are changed, the old metrics are deleted + small_sorting_analyzer.compute( + {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + ) + + template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") + + assert "half_width" not in list(template_metric_extension.get_data().keys()) + + assert small_sorting_analyzer.get_extension("quality_metrics") is None + + class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index b85d3bcfd3..ebd6439be8 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -151,9 +151,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job import pandas as pd - metrics = self.data.get("metrics") - if metrics is None: - metrics = pd.DataFrame(index=sorting_analyzer.sorting.unit_ids) + metrics = pd.DataFrame(index=sorting_analyzer.sorting.unit_ids) # simple metrics not based on PCs for metric_name in metric_names: From 3360022f5cfd38e9f7ac3b43cbfefa257d0b8695 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Sat, 7 Sep 2024 13:20:38 +0100 Subject: [PATCH 08/17] Tests now pass --- src/spikeinterface/postprocessing/template_metrics.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 57d8fd5839..fef35bfc59 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -144,7 +144,6 @@ def _set_params( else: metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) - print(metrics_kwargs_) all_metric_names = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") @@ -160,6 +159,7 @@ def _set_params( self.sorting_analyzer.get_extension("template_metrics").data["metrics"] = pd.DataFrame( index=self.sorting_analyzer.unit_ids ) + self.sorting_analyzer.get_extension("template_metrics").params["metric_names"] = [] existing_metric_names = [] for metric_name in existing_metric_names: @@ -315,9 +315,11 @@ def _run(self, delete_existing_metrics=False, verbose=False): tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_metrics = tm_extension.get_data() - for metric_name, metric_data in existing_metrics.items(): + existing_metrics = tm_extension.params["metric_names"] + + for metric_name in existing_metrics: if metric_name not in self.data["metrics"]: + metric_data = tm_extension.get_data()[metric_name] self.data["metrics"][metric_name] = metric_data def _get_data(self): From 3cc1298d8ecb9ce0dfa487a8f93f3b14f9a6ba90 Mon 
Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 9 Sep 2024 08:20:08 +0100 Subject: [PATCH 09/17] tests definitely 100% pass now --- src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index ebd6439be8..31353df724 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -108,6 +108,8 @@ def _select_extension_data(self, unit_ids): def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): + import pandas as pd + old_metrics = self.data["metrics"] all_unit_ids = new_sorting_analyzer.unit_ids From 689c633a697964a9e18673aae5f2c3528382e716 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:06:06 +0100 Subject: [PATCH 10/17] Update template metrics based on Joe feedback --- .../postprocessing/template_metrics.py | 61 +++++++++++-------- .../tests/test_template_metrics.py | 53 +++++++++++++++- 2 files changed, 88 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index fef35bfc59..15b8c85e38 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -65,7 +65,7 @@ class ComputeTemplateMetrics(AnalyzerExtension): include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False - If True, deletes any quality_metrics attached to the `sorting_analyzer` + If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged. metrics_kwargs : dict Additional arguments to pass to the metric functions. 
Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 @@ -116,6 +116,7 @@ def _set_params( delete_existing_metrics=False, **other_kwargs, ): + import pandas as pd # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() @@ -135,6 +136,10 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() + # `run` cannot take parameters, so need to find another way to pass this + self.delete_existing_metrics = delete_existing_metrics + self.metric_names = metric_names + if metrics_kwargs is None: metrics_kwargs_ = _default_function_kwargs.copy() if len(other_kwargs) > 0: @@ -145,29 +150,24 @@ def _set_params( metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) - all_metric_names = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_metric_names = tm_extension.params["metric_names"] - existing_params = tm_extension.params["metrics_kwargs"] + existing_params = tm_extension.params["metrics_kwargs"] # checks that existing metrics were calculated using the same params if existing_params != metrics_kwargs_: warnings.warn( "The parameters used to calculate the previous template metrics are different than those used now. Deleting previous template metrics..." ) - self.sorting_analyzer.get_extension("template_metrics").data["metrics"] = pd.DataFrame( - index=self.sorting_analyzer.unit_ids - ) - self.sorting_analyzer.get_extension("template_metrics").params["metric_names"] = [] + tm_extension.params["metric_names"] = [] existing_metric_names = [] + else: + existing_metric_names = tm_extension.params["metric_names"] - for metric_name in existing_metric_names: - if metric_name not in metric_names: - all_metric_names.append(metric_name) + metric_names = list(set(existing_metric_names + metric_names)) params = dict( - metric_names=[str(name) for name in np.unique(all_metric_names)], + metric_names=metric_names, sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), @@ -185,6 +185,7 @@ def _merge_extension_data( ): import pandas as pd + metric_names = self.params["metric_names"] old_metrics = self.data["metrics"] all_unit_ids = new_sorting_analyzer.unit_ids @@ -193,19 +194,20 @@ def _merge_extension_data( metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] - metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs) + metrics.loc[new_unit_ids, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs + ) new_data = dict(metrics=metrics) return new_data - def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs): + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute template metrics. 
""" import pandas as pd from scipy.signal import resample_poly - metric_names = self.params["metric_names"] sparsity = self.params["sparsity"] peak_sign = self.params["peak_sign"] upsampling_factor = self.params["upsampling_factor"] @@ -308,19 +310,30 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job template_metrics.at[index, metric_name] = value return template_metrics - def _run(self, delete_existing_metrics=False, verbose=False): - self.data["metrics"] = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose - ) + def _run(self, verbose=False): + + delete_existing_metrics = self.delete_existing_metrics + metric_names = self.metric_names + existing_metrics = [] tm_extension = self.sorting_analyzer.get_extension("template_metrics") - if delete_existing_metrics is False and tm_extension is not None: + if ( + delete_existing_metrics is False + and tm_extension is not None + and tm_extension.data.get("metrics") is not None + ): existing_metrics = tm_extension.params["metric_names"] - for metric_name in existing_metrics: - if metric_name not in self.data["metrics"]: - metric_data = tm_extension.get_data()[metric_name] - self.data["metrics"][metric_name] = metric_data + # compute the metrics which have been specified by the user + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names + ) + + # append the metrics which were previously computed + for metric_name in set(existing_metrics).difference(metric_names): + computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] + + self.data["metrics"] = computed_metrics def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index f444e12c36..1fa2ac638c 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,18 +1,22 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeTemplateMetrics import pytest +import csv def test_compute_new_template_metrics(small_sorting_analyzer): """ - Computes template metrics then computes a subset of quality metrics, and checks - that the old quality metrics are not deleted. + Computes template metrics then computes a subset of template metrics, and checks + that the old template metrics are not deleted. Then computes template metrics with new parameters and checks that old metrics are deleted. """ + # calculate all template metrics small_sorting_analyzer.compute("template_metrics") + + # calculate just exp_decay - this should not delete the previously calculated metrics small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") @@ -33,6 +37,51 @@ def test_compute_new_template_metrics(small_sorting_analyzer): assert small_sorting_analyzer.get_extension("quality_metrics") is None +def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): + """ + Computes template metrics in binary folder format. Then computes subsets of template + metrics and checks if they are saved correctly. 
+ """ + + from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func + + small_sorting_analyzer.compute("template_metrics") + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + template_metrics_filename = output_folder / "extensions" / "template_metrics" / "metrics.csv" + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in list(_single_channel_metric_name_to_func.keys()): + assert metric_name in metric_names + + folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False) + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in list(_single_channel_metric_name_to_func.keys()): + assert metric_name in metric_names + + folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True) + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in list(_single_channel_metric_name_to_func.keys()): + if metric_name == "half_width": + assert metric_name in metric_names + else: + assert metric_name not in metric_names + + class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( From 6cbe4dbe9c805aac3fb3ed29a0f50f4623eeab9c Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:25:30 +0100 Subject: [PATCH 11/17] Improve tests for template metrics --- .../tests/test_template_metrics.py | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 1fa2ac638c..8aaad8ffbc 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -3,6 +3,10 @@ import pytest import csv +from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func + +template_metrics = list(_single_channel_metric_name_to_func.keys()) + def test_compute_new_template_metrics(small_sorting_analyzer): """ @@ -13,28 +17,38 @@ def test_compute_new_template_metrics(small_sorting_analyzer): are deleted. 
""" + # calculate just exp_decay + small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) + template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") + + assert "exp_decay" in list(template_metric_extension.get_data().keys()) + assert "half_width" not in list(template_metric_extension.get_data().keys()) + # calculate all template metrics small_sorting_analyzer.compute("template_metrics") - - # calculate just exp_decay - this should not delete the previously calculated metrics + # calculate just exp_decay - this should not delete any other metrics small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) - template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") - # Check old metrics are not deleted and the new one is added to the data and metadata - assert "exp_decay" in list(template_metric_extension.get_data().keys()) - assert "half_width" in list(template_metric_extension.get_data().keys()) + set(template_metrics) == set(template_metric_extension.get_data().keys()) - # check that, when parameters are changed, the old metrics are deleted + # calculate just exp_decay with delete_existing_metrics small_sorting_analyzer.compute( - {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + {"template_metrics": {"metric_names": ["exp_decay"], "delete_existing_metrics": True}} ) - template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") + computed_metric_names = template_metric_extension.get_data().keys() - assert "half_width" not in list(template_metric_extension.get_data().keys()) + for metric_name in template_metrics: + if metric_name == "exp_decay": + assert metric_name in computed_metric_names + else: + assert metric_name not in computed_metric_names - assert small_sorting_analyzer.get_extension("quality_metrics") is None + # check that, when parameters are changed, the old metrics are deleted + small_sorting_analyzer.compute( + {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + ) def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): @@ -43,8 +57,6 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): metrics and checks if they are saved correctly. 
""" - from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func - small_sorting_analyzer.compute("template_metrics") cache_folder = create_cache_folder @@ -57,7 +69,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in list(_single_channel_metric_name_to_func.keys()): + for metric_name in template_metrics: assert metric_name in metric_names folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False) @@ -66,7 +78,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in list(_single_channel_metric_name_to_func.keys()): + for metric_name in template_metrics: assert metric_name in metric_names folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True) @@ -75,7 +87,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in list(_single_channel_metric_name_to_func.keys()): + for metric_name in template_metrics: if metric_name == "half_width": assert metric_name in metric_names else: From 4589c6efae21d540345ad1ba858e53828d441e48 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 09:28:00 +0100 Subject: [PATCH 12/17] Add ordering and propogate through params --- .../postprocessing/template_metrics.py | 18 +++++++++++------- .../tests/test_template_metrics.py | 11 +++++++++++ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 15b8c85e38..062b0bd76b 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -137,8 +137,7 @@ def _set_params( metric_names += get_multi_channel_template_metric_names() # `run` cannot take parameters, so need to find another way to pass this - self.delete_existing_metrics = delete_existing_metrics - self.metric_names = metric_names + metric_names_to_compute = metric_names if metrics_kwargs is None: metrics_kwargs_ = _default_function_kwargs.copy() @@ -164,7 +163,10 @@ def _set_params( else: existing_metric_names = tm_extension.params["metric_names"] - metric_names = list(set(existing_metric_names + metric_names)) + existing_metric_names_propogated = [ + metric_name for metric_name in existing_metric_names if metric_name not in metric_names_to_compute + ] + metric_names = metric_names_to_compute + existing_metric_names_propogated params = dict( metric_names=metric_names, @@ -172,6 +174,8 @@ def _set_params( peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), metrics_kwargs=metrics_kwargs_, + delete_existing_metrics=delete_existing_metrics, + metric_names_to_compute=metric_names_to_compute, ) return params @@ -312,8 +316,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri def _run(self, verbose=False): - delete_existing_metrics = self.delete_existing_metrics - metric_names = self.metric_names + delete_existing_metrics = self.params["delete_existing_metrics"] + metric_names_to_compute = self.params["metric_names_to_compute"] existing_metrics = [] tm_extension = 
self.sorting_analyzer.get_extension("template_metrics") @@ -326,11 +330,11 @@ def _run(self, verbose=False): # compute the metrics which have been specified by the user computed_metrics = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names + sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names_to_compute ) # append the metrics which were previously computed - for metric_name in set(existing_metrics).difference(metric_names): + for metric_name in set(existing_metrics).difference(metric_names_to_compute): computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] self.data["metrics"] = computed_metrics diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 8aaad8ffbc..5056d4ff2a 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -51,6 +51,17 @@ def test_compute_new_template_metrics(small_sorting_analyzer): ) +def test_metric_names_in_same_order(small_sorting_analyzer): + """ + Computes sepecified template metrics and checks order is propogated. + """ + specified_metric_names = ["peak_trough_ratio", "num_negative_peaks", "half_width"] + small_sorting_analyzer.compute("template_metrics", metric_names=specified_metric_names) + tm_keys = small_sorting_analyzer.get_extension("template_metrics").get_data().keys() + for i in range(3): + assert specified_metric_names[i] == tm_keys[i] + + def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): """ Computes template metrics in binary folder format. Then computes subsets of template From accf40a30660900dabeadd5f8b0d17190ed6d3a4 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:19:31 +0100 Subject: [PATCH 13/17] Update quality metrics --- .../postprocessing/template_metrics.py | 24 ++-- .../quality_metric_calculator.py | 64 ++++++--- .../qualitymetrics/quality_metric_list.py | 26 ++++ .../tests/test_metrics_functions.py | 122 ++++++++++++++++-- 4 files changed, 192 insertions(+), 44 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 062b0bd76b..5f4c1e904b 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -136,9 +136,6 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - # `run` cannot take parameters, so need to find another way to pass this - metric_names_to_compute = metric_names - if metrics_kwargs is None: metrics_kwargs_ = _default_function_kwargs.copy() if len(other_kwargs) > 0: @@ -149,6 +146,7 @@ def _set_params( metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) + metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: @@ -164,9 +162,9 @@ def _set_params( existing_metric_names = tm_extension.params["metric_names"] existing_metric_names_propogated = [ - metric_name for metric_name in existing_metric_names if metric_name not in metric_names_to_compute + metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute ] - 
metric_names = metric_names_to_compute + existing_metric_names_propogated + metric_names = metrics_to_compute + existing_metric_names_propogated params = dict( metric_names=metric_names, @@ -175,7 +173,7 @@ def _set_params( upsampling_factor=int(upsampling_factor), metrics_kwargs=metrics_kwargs_, delete_existing_metrics=delete_existing_metrics, - metric_names_to_compute=metric_names_to_compute, + metrics_to_compute=metrics_to_compute, ) return params @@ -317,7 +315,12 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri def _run(self, verbose=False): delete_existing_metrics = self.params["delete_existing_metrics"] - metric_names_to_compute = self.params["metric_names_to_compute"] + metrics_to_compute = self.params["metrics_to_compute"] + + # compute the metrics which have been specified by the user + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metrics_to_compute + ) existing_metrics = [] tm_extension = self.sorting_analyzer.get_extension("template_metrics") @@ -328,13 +331,8 @@ def _run(self, verbose=False): ): existing_metrics = tm_extension.params["metric_names"] - # compute the metrics which have been specified by the user - computed_metrics = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names_to_compute - ) - # append the metrics which were previously computed - for metric_name in set(existing_metrics).difference(metric_names_to_compute): + for metric_name in set(existing_metrics).difference(metrics_to_compute): computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] self.data["metrics"] = computed_metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 31353df724..1c7483212a 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -11,7 +11,12 @@ from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from .quality_metric_list import compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names +from .quality_metric_list import ( + compute_pc_metrics, + _misc_metric_name_to_func, + _possible_pc_metric_names, + compute_name_to_column_names, +) from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params @@ -32,7 +37,7 @@ class ComputeQualityMetrics(AnalyzerExtension): skip_pc_metrics : bool, default: False If True, PC metrics computation is skipped. delete_existing_metrics : bool, default: False - If True, deletes any quality_metrics attached to the `sorting_analyzer` + If True, any quality metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept. 
Returns ------- @@ -81,21 +86,24 @@ def _set_params( if "peak_sign" in qm_params_[k] and peak_sign is not None: qm_params_[k]["peak_sign"] = peak_sign - all_metric_names = metric_names + metrics_to_compute = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") if delete_existing_metrics is False and qm_extension is not None: - existing_params = qm_extension.params - for metric_name in existing_params["metric_names"]: - if metric_name not in metric_names: - all_metric_names.append(metric_name) - qm_params_[metric_name] = existing_params["qm_params"][metric_name] + + existing_metric_names = qm_extension.params["metric_names"] + existing_metric_names_propogated = [ + metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute + ] + metric_names = metrics_to_compute + existing_metric_names_propogated params = dict( - metric_names=[str(name) for name in np.unique(all_metric_names)], + metric_names=metric_names, peak_sign=peak_sign, seed=seed, qm_params=qm_params_, skip_pc_metrics=skip_pc_metrics, + delete_existing_metrics=delete_existing_metrics, + metrics_to_compute=metrics_to_compute, ) return params @@ -123,11 +131,11 @@ def _merge_extension_data( new_data = dict(metrics=metrics) return new_data - def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs): + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute quality metrics. """ - metric_names = self.params["metric_names"] + qm_params = self.params["qm_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] @@ -203,17 +211,35 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job return metrics - def _run(self, verbose=False, delete_existing_metrics=False, **job_kwargs): - self.data["metrics"] = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, **job_kwargs + def _run(self, verbose=False, **job_kwargs): + + metrics_to_compute = self.params["metrics_to_compute"] + delete_existing_metrics = self.params["delete_existing_metrics"] + + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, + unit_ids=None, + verbose=verbose, + metric_names=metrics_to_compute, + **job_kwargs, ) + existing_metrics = [] qm_extension = self.sorting_analyzer.get_extension("quality_metrics") - if delete_existing_metrics is False and qm_extension is not None: - existing_metrics = qm_extension.get_data() - for metric_name, metric_data in existing_metrics.items(): - if metric_name not in self.data["metrics"]: - self.data["metrics"][metric_name] = metric_data + if ( + delete_existing_metrics is False + and qm_extension is not None + and qm_extension.data.get("metrics") is not None + ): + existing_metrics = qm_extension.params["metric_names"] + + # append the metrics which were previously computed + for metric_name in set(existing_metrics).difference(metrics_to_compute): + # some metrics names produce data columns with other names. This deals with that. 
+            for column_name in compute_name_to_column_names[metric_name]:
+                computed_metrics[column_name] = qm_extension.data["metrics"][column_name]
+
+        self.data["metrics"] = computed_metrics
 
     def _get_data(self):
         return self.data["metrics"]
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py
index 140ad87a8b..375dd320ae 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_list.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -53,3 +53,29 @@
     "drift": compute_drift_metrics,
     "sd_ratio": compute_sd_ratio,
 }
+
+# a dict converting the name of each metric computation to the column names of its output
+compute_name_to_column_names = {
+    "num_spikes": ["num_spikes"],
+    "firing_rate": ["firing_rate"],
+    "presence_ratio": ["presence_ratio"],
+    "snr": ["snr"],
+    "isi_violation": ["isi_violations_ratio", "isi_violations_count"],
+    "rp_violation": ["rp_violations", "rp_contamination"],
+    "sliding_rp_violation": ["sliding_rp_violation"],
+    "amplitude_cutoff": ["amplitude_cutoff"],
+    "amplitude_median": ["amplitude_median"],
+    "amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"],
+    "synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"],
+    "firing_range": ["firing_range"],
+    "drift": ["drift_ptp", "drift_std", "drift_mad"],
+    "sd_ratio": ["sd_ratio"],
+    "isolation_distance": ["isolation_distance"],
+    "l_ratio": ["l_ratio"],
+    "d_prime": ["d_prime"],
+    "nearest_neighbor": ["nn_hit_rate", "nn_miss_rate"],
+    "nn_isolation": ["nn_isolation", "nn_unit_id"],
+    "nn_noise_overlap": ["nn_noise_overlap"],
+    "silhouette": ["silhouette"],
+    "silhouette_full": ["silhouette_full"],
+}
diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index e34c15c936..77909798a3 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 import numpy as np
 from copy import deepcopy
+import csv
 from spikeinterface.core import (
     NumpySorting,
     synthetize_spike_train_bad_isi,
@@ -42,6 +43,7 @@
     compute_quality_metrics,
 )
 
+
 from spikeinterface.core.basesorting import minimum_spike_dtype
 
@@ -60,6 +62,12 @@ def test_compute_new_quality_metrics(small_sorting_analyzer):
         "firing_range": {"bin_size_s": 1},
     }
 
+    small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}})
+    qm_extension = small_sorting_analyzer.get_extension("quality_metrics")
+    calculated_metrics = list(qm_extension.get_data().keys())
+
+    assert calculated_metrics == ["snr"]
+
     small_sorting_analyzer.compute(
         {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}}
     )
@@ -68,18 +76,22 @@
     quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics")
 
     # Check old metrics are not deleted and the new one is added to the data and metadata
-    assert list(quality_metric_extension.get_data().keys()) == [
-        "amplitude_cutoff",
-        "firing_range",
-        "presence_ratio",
-        "snr",
-    ]
-    assert list(quality_metric_extension.params.get("metric_names")) == [
-        "amplitude_cutoff",
-        "firing_range",
-        "presence_ratio",
-        "snr",
-    ]
+    assert set(list(quality_metric_extension.get_data().keys())) == set(
+        [
+            "amplitude_cutoff",
+            "firing_range",
+            "presence_ratio",
+            "snr",
+        ]
+    )
+    assert set(list(quality_metric_extension.params.get("metric_names"))) == set(
+        [
+            "amplitude_cutoff",
+            "firing_range",
+            "presence_ratio",
+            "snr",
+        ]
+    )
 
     # check that, when parameters are changed, the data and metadata are updated
     old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values)
@@ -106,6 +118,92 @@ def test_compute_new_quality_metrics(small_sorting_analyzer):
 
     assert small_sorting_analyzer.get_extension("quality_metrics") is None
 
 
+def test_metric_names_in_same_order(small_sorting_analyzer):
+    """
+    Computes specified quality metrics and checks that their order is propagated to the output.
+    """
+    specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"]
+    small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names)
+    qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys()
+    for i in range(3):
+        assert specified_metric_names[i] == qm_keys[i]
+
+
+def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
+    """
+    Computes quality metrics in binary folder format. Then computes subsets of quality
+    metrics and checks if they are saved correctly.
+    """
+
+    # can't use _misc_metric_name_to_func as some functions compute several qms
+    # e.g. isi_violation and synchrony
+    quality_metrics = [
+        "num_spikes",
+        "firing_rate",
+        "presence_ratio",
+        "snr",
+        "isi_violations_ratio",
+        "isi_violations_count",
+        "rp_contamination",
+        "rp_violations",
+        "sliding_rp_violation",
+        "amplitude_cutoff",
+        "amplitude_median",
+        "amplitude_cv_median",
+        "amplitude_cv_range",
+        "sync_spike_2",
+        "sync_spike_4",
+        "sync_spike_8",
+        "firing_range",
+        "drift_ptp",
+        "drift_std",
+        "drift_mad",
+        "sd_ratio",
+        "isolation_distance",
+        "l_ratio",
+        "d_prime",
+        "silhouette",
+        "nn_hit_rate",
+        "nn_miss_rate",
+    ]
+
+    small_sorting_analyzer.compute("quality_metrics")
+
+    cache_folder = create_cache_folder
+    output_folder = cache_folder / "sorting_analyzer"
+
+    folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder)
+    quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv"
+
+    with open(quality_metrics_filename) as metrics_file:
+        saved_metrics = csv.reader(metrics_file)
+        metric_names = next(saved_metrics)
+
+    for metric_name in quality_metrics:
+        assert metric_name in metric_names
+
+    folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False)
+
+    with open(quality_metrics_filename) as metrics_file:
+        saved_metrics = csv.reader(metrics_file)
+        metric_names = next(saved_metrics)
+
+    for metric_name in quality_metrics:
+        assert metric_name in metric_names
+
+    folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True)
+
+    with open(quality_metrics_filename) as metrics_file:
+        saved_metrics = csv.reader(metrics_file)
+        metric_names = next(saved_metrics)
+
+    for metric_name in quality_metrics:
+        if metric_name == "snr":
+            assert metric_name in metric_names
+        else:
+            assert metric_name not in metric_names
+
+
 def test_unit_structure_in_output(small_sorting_analyzer):
 
     qm_params = {

From c1f0b2a8ae1996f30b2612b92aa3fd48e50dba3a Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Wed, 11 Sep 2024 10:34:43 +0100
Subject: [PATCH 14/17] Update merge_extension_data for quality_metrics

---
 .../qualitymetrics/quality_metric_calculator.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index 1c7483212a..a143ac3562 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -63,6 +63,7 @@ def _set_params(
         seed=None,
         skip_pc_metrics=False,
         delete_existing_metrics=False,
+        metrics_to_compute=None,
     ):
 
         if metric_names is None:
@@ -118,6 +119,7 @@ def _merge_extension_data(
     ):
         import pandas as pd
 
+        metric_names = self.params["metric_names"]
         old_metrics = self.data["metrics"]
 
         all_unit_ids = new_sorting_analyzer.unit_ids
@@ -126,7 +128,9 @@ def _merge_extension_data(
 
         metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
         metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
-        metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
+        metrics.loc[new_unit_ids, :] = self._compute_metrics(
+            new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs
+        )
 
         new_data = dict(metrics=metrics)
         return new_data

From 9fb7f344b578e0bf2a9436608757fdffaa008e06 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Wed, 11 Sep 2024 10:59:35 +0100
Subject: [PATCH 15/17] use sa unit ids and switch order of id indep test

---
 .../qualitymetrics/quality_metric_calculator.py    | 2 +-
 .../qualitymetrics/tests/test_metrics_functions.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index 123293f313..8a754fc7da 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -165,7 +165,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
 
         import pandas as pd
 
-        metrics = pd.DataFrame(index=sorting_analyzer.sorting.unit_ids)
+        metrics = pd.DataFrame(index=sorting_analyzer.unit_ids)
 
         # simple metrics not based on PCs
         for metric_name in metric_names:
diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index 77909798a3..ee5d5849b3 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -286,10 +286,10 @@ def test_unit_id_order_independence(small_sorting_analyzer):
         small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params
     )
 
-    for metric, metric_1_data in quality_metrics_1.items():
-        assert quality_metrics_2[metric][2] == metric_1_data["#3"]
-        assert quality_metrics_2[metric][7] == metric_1_data["#9"]
-        assert quality_metrics_2[metric][1] == metric_1_data["#4"]
+    for metric, metric_2_data in quality_metrics_2.items():
+        assert quality_metrics_1[metric]["#3"] == metric_2_data[2]
+        assert quality_metrics_1[metric]["#9"] == metric_2_data[7]
+        assert quality_metrics_1[metric]["#4"] == metric_2_data[1]
 
 
 def _sorting_analyzer_simple():

From 406f99bf7d2f68b68db3e4535ba6e004331de419 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Thu, 12 Sep 2024 08:39:56 +0100
Subject: [PATCH 16/17] update warning to include metric names

---
 src/spikeinterface/postprocessing/template_metrics.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index d05e3ae7ef..0d0d633c04 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -152,7 +152,9 @@ def _set_params(
             # checks that existing metrics were calculated using the same params
             if existing_params != metrics_kwargs_:
                 warnings.warn(
-                    "The parameters used to calculate the previous template metrics are different than those used now. Deleting previous template metrics..."
+                    f"The parameters used to calculate the previous template metrics are different "
+                    f"than those used now.\nPrevious parameters: {existing_params}\nCurrent "
+                    f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..."
                 )
                 tm_extension.params["metric_names"] = []
                 existing_metric_names = []

From 6c8889d74d5e8c4c2b5fa073f9b5cfcd7b9141b6 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 13 Sep 2024 12:39:05 +0200
Subject: [PATCH 17/17] Update src/spikeinterface/qualitymetrics/quality_metric_calculator.py

Co-authored-by: Garcia Samuel
---
 src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index 8a754fc7da..52eb56c4ee 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -165,7 +165,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
 
         import pandas as pd
 
-        metrics = pd.DataFrame(index=sorting_analyzer.unit_ids)
+        metrics = pd.DataFrame(index=unit_ids)
 
         # simple metrics not based on PCs
         for metric_name in metric_names:
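
A minimal usage sketch (not part of the patch series) of the behaviour these commits introduce: recomputing quality metrics no longer wipes previously computed ones unless delete_existing_metrics=True is passed. This is a hedged illustration rather than the PR's own example: generate_ground_truth_recording and create_sorting_analyzer are existing spikeinterface.core helpers used here only to build a small analyzer, the delete_existing_metrics flag is the one added by these commits, and the exact column order of the resulting DataFrame may differ from the comments below.

# sketch.py -- illustrative only; assumes SpikeInterface with this branch installed
import spikeinterface.full as si

# build a small in-memory analyzer from a synthetic ground-truth recording
recording, sorting = si.generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)

# parent extensions required by the "snr" metric
analyzer.compute(["random_spikes", "waveforms", "templates", "noise_levels"])

# first pass: compute a single metric
analyzer.compute("quality_metrics", metric_names=["snr"])
print(sorted(analyzer.get_extension("quality_metrics").get_data().columns))  # ['snr']

# second pass: add another metric; the existing "snr" column is kept
analyzer.compute("quality_metrics", metric_names=["firing_range"], delete_existing_metrics=False)
print(sorted(analyzer.get_extension("quality_metrics").get_data().columns))  # ['firing_range', 'snr']

# opting back into the old behaviour drops everything except the requested metrics
analyzer.compute("quality_metrics", metric_names=["firing_range"], delete_existing_metrics=True)
print(sorted(analyzer.get_extension("quality_metrics").get_data().columns))  # ['firing_range']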