diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index 45ba55dee4..0d0d633c04 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -62,6 +62,8 @@ class ComputeTemplateMetrics(AnalyzerExtension):
         For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function.
     include_multi_channel_metrics : bool, default: False
         Whether to compute multi-channel metrics
+    delete_existing_metrics : bool, default: False
+        If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged.
     metrics_kwargs : dict
         Additional arguments to pass to the metric functions. Including:
             * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7
@@ -109,9 +111,12 @@ def _set_params(
         sparsity=None,
         metrics_kwargs=None,
         include_multi_channel_metrics=False,
+        delete_existing_metrics=False,
         **other_kwargs,
     ):
+        import pandas as pd
+
         # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory()
         if include_multi_channel_metrics or (
             metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names])
@@ -139,12 +144,36 @@ def _set_params(
             metrics_kwargs_ = _default_function_kwargs.copy()
             metrics_kwargs_.update(metrics_kwargs)
 
+        metrics_to_compute = metric_names
+        tm_extension = self.sorting_analyzer.get_extension("template_metrics")
+        if delete_existing_metrics is False and tm_extension is not None:
+
+            existing_params = tm_extension.params["metrics_kwargs"]
+            # check that the existing metrics were calculated using the same params
+            if existing_params != metrics_kwargs_:
+                warnings.warn(
+                    f"The parameters used to calculate the previous template metrics are different "
+                    f"than those used now.\nPrevious parameters: {existing_params}\nCurrent "
+                    f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..."
+                )
+                tm_extension.params["metric_names"] = []
+                existing_metric_names = []
+            else:
+                existing_metric_names = tm_extension.params["metric_names"]
+
+            existing_metric_names_propogated = [
+                metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute
+            ]
+            metric_names = metrics_to_compute + existing_metric_names_propogated
+
         params = dict(
-            metric_names=[str(name) for name in np.unique(metric_names)],
+            metric_names=metric_names,
             sparsity=sparsity,
             peak_sign=peak_sign,
             upsampling_factor=int(upsampling_factor),
             metrics_kwargs=metrics_kwargs_,
+            delete_existing_metrics=delete_existing_metrics,
+            metrics_to_compute=metrics_to_compute,
         )
 
         return params
@@ -158,6 +187,7 @@ def _merge_extension_data(
     ):
         import pandas as pd
 
+        metric_names = self.params["metric_names"]
         old_metrics = self.data["metrics"]
 
         all_unit_ids = new_sorting_analyzer.unit_ids
@@ -166,19 +196,20 @@ def _merge_extension_data(
         metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
         metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
-        metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
+        metrics.loc[new_unit_ids, :] = self._compute_metrics(
+            new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs
+        )
 
         new_data = dict(metrics=metrics)
         return new_data
 
-    def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs):
+    def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs):
         """
         Compute template metrics.
         """
         import pandas as pd
         from scipy.signal import resample_poly
 
-        metric_names = self.params["metric_names"]
         sparsity = self.params["sparsity"]
         peak_sign = self.params["peak_sign"]
         upsampling_factor = self.params["upsampling_factor"]
@@ -290,10 +321,30 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
 
         return template_metrics
 
     def _run(self, verbose=False):
-        self.data["metrics"] = self._compute_metrics(
-            sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose
+
+        delete_existing_metrics = self.params["delete_existing_metrics"]
+        metrics_to_compute = self.params["metrics_to_compute"]
+
+        # compute the metrics which have been specified by the user
+        computed_metrics = self._compute_metrics(
+            sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metrics_to_compute
         )
 
+        existing_metrics = []
+        tm_extension = self.sorting_analyzer.get_extension("template_metrics")
+        if (
+            delete_existing_metrics is False
+            and tm_extension is not None
+            and tm_extension.data.get("metrics") is not None
+        ):
+            existing_metrics = tm_extension.params["metric_names"]
+
+        # append the metrics which were previously computed
+        for metric_name in set(existing_metrics).difference(metrics_to_compute):
+            computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name]
+
+        self.data["metrics"] = computed_metrics
+
     def _get_data(self):
         return self.data["metrics"]
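For reference, a minimal usage sketch of the behaviour implemented above (not part of the patch). The analyzer setup mirrors the `small_sorting_analyzer` fixture added below; the values are illustrative, and the extra keyword arguments are assumed to be forwarded to `_set_params` exactly as in the hunk above.

    import spikeinterface.full as si

    # build a small analyzer with templates available (mirrors the new conftest fixture)
    recording, sorting = si.generate_ground_truth_recording(durations=[2.0], num_units=5, seed=0)
    analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format="memory")
    analyzer.compute(["random_spikes", "waveforms", "templates"])

    # compute all template metrics, then recompute only "half_width"
    analyzer.compute("template_metrics")
    analyzer.compute("template_metrics", metric_names=["half_width"])  # previously computed metrics are kept
    print(analyzer.get_extension("template_metrics").get_data().columns)

    # delete_existing_metrics=True keeps only the metrics requested in this call
    analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True)
    print(analyzer.get_extension("template_metrics").get_data().columns)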
diff --git a/src/spikeinterface/postprocessing/tests/conftest.py b/src/spikeinterface/postprocessing/tests/conftest.py
new file mode 100644
index 0000000000..51ac8aa250
--- /dev/null
+++ b/src/spikeinterface/postprocessing/tests/conftest.py
@@ -0,0 +1,33 @@
+import pytest
+
+from spikeinterface.core import (
+    generate_ground_truth_recording,
+    create_sorting_analyzer,
+)
+
+
+def _small_sorting_analyzer():
+    recording, sorting = generate_ground_truth_recording(
+        durations=[2.0],
+        num_units=10,
+        seed=1205,
+    )
+
+    sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
+
+    extensions_to_compute = {
+        "random_spikes": {"seed": 1205},
+        "noise_levels": {"seed": 1205},
+        "waveforms": {},
+        "templates": {"operators": ["average", "median"]},
+        "spike_amplitudes": {},
+    }
+
+    sorting_analyzer.compute(extensions_to_compute)
+
+    return sorting_analyzer
+
+
+@pytest.fixture(scope="module")
+def small_sorting_analyzer():
+    return _small_sorting_analyzer()
diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py
index 694aa083cc..5056d4ff2a 100644
--- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py
+++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py
@@ -1,6 +1,108 @@
 from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite
 from spikeinterface.postprocessing import ComputeTemplateMetrics
 import pytest
+import csv
+
+from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func
+
+template_metrics = list(_single_channel_metric_name_to_func.keys())
+
+
+def test_compute_new_template_metrics(small_sorting_analyzer):
+    """
+    Computes template metrics then computes a subset of template metrics, and checks
+    that the old template metrics are not deleted.
+
+    Then computes template metrics with new parameters and checks that old metrics
+    are deleted.
+    """
+
+    # calculate just exp_decay
+    small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
+    template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
+
+    assert "exp_decay" in list(template_metric_extension.get_data().keys())
+    assert "half_width" not in list(template_metric_extension.get_data().keys())
+
+    # calculate all template metrics
+    small_sorting_analyzer.compute("template_metrics")
+    # calculate just exp_decay - this should not delete any other metrics
+    small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
+    template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
+
+    assert set(template_metrics) == set(template_metric_extension.get_data().keys())
+
+    # calculate just exp_decay with delete_existing_metrics
+    small_sorting_analyzer.compute(
+        {"template_metrics": {"metric_names": ["exp_decay"], "delete_existing_metrics": True}}
+    )
+    template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
+    computed_metric_names = template_metric_extension.get_data().keys()
+
+    for metric_name in template_metrics:
+        if metric_name == "exp_decay":
+            assert metric_name in computed_metric_names
+        else:
+            assert metric_name not in computed_metric_names
+
+    # check that, when parameters are changed, the old metrics are deleted
+    small_sorting_analyzer.compute(
+        {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
+    )
+
+
+def test_metric_names_in_same_order(small_sorting_analyzer):
+    """
+    Computes specified template metrics and checks that the order is propagated.
+ """ + specified_metric_names = ["peak_trough_ratio", "num_negative_peaks", "half_width"] + small_sorting_analyzer.compute("template_metrics", metric_names=specified_metric_names) + tm_keys = small_sorting_analyzer.get_extension("template_metrics").get_data().keys() + for i in range(3): + assert specified_metric_names[i] == tm_keys[i] + + +def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): + """ + Computes template metrics in binary folder format. Then computes subsets of template + metrics and checks if they are saved correctly. + """ + + small_sorting_analyzer.compute("template_metrics") + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + template_metrics_filename = output_folder / "extensions" / "template_metrics" / "metrics.csv" + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in template_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False) + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in template_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True) + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in template_metrics: + if metric_name == "half_width": + assert metric_name in metric_names + else: + assert metric_name not in metric_names class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite): diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 2de31ad750..8dfd41cf88 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -69,6 +69,9 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): return num_spikes +_default_params["num_spikes"] = {} + + def compute_firing_rates(sorting_analyzer, unit_ids=None): """ Compute the firing rate across segments. @@ -98,6 +101,9 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None): return firing_rates +_default_params["firing_rate"] = {} + + def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. 
@@ -1550,3 +1556,10 @@ def compute_sd_ratio(
         sd_ratio[unit_id] = unit_std / std_noise
 
     return sd_ratio
+
+
+_default_params["sd_ratio"] = dict(
+    censored_period_ms=4.0,
+    correct_for_drift=True,
+    correct_for_template_itself=True,
+)
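The `_default_params` registrations above give every misc metric an entry in the default quality-metric parameters. A quick sanity check, assuming `get_default_qm_params()` (referenced in the docstring below) is built from these module-level dicts, which is how quality_metric_calculator.py imports them:

    import spikeinterface.qualitymetrics as sqm

    default_params = sqm.get_default_qm_params()
    assert default_params["num_spikes"] == {}  # no tunable parameters, registered above
    assert default_params["firing_rate"] == {}
    assert default_params["sd_ratio"]["censored_period_ms"] == 4.0  # registered in the hunk above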
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index cdf6151e95..52eb56c4ee 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-
 import warnings
 from copy import deepcopy
 
@@ -12,7 +11,12 @@
 
 from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension
 
-from .quality_metric_list import compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names
+from .quality_metric_list import (
+    compute_pc_metrics,
+    _misc_metric_name_to_func,
+    _possible_pc_metric_names,
+    compute_name_to_column_names,
+)
 from .misc_metrics import _default_params as misc_metrics_params
 from .pca_metrics import _default_params as pca_metrics_params
 
@@ -30,8 +34,10 @@ class ComputeQualityMetrics(AnalyzerExtension):
     qm_params : dict or None
         Dictionary with parameters for quality metrics calculation.
         Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()`
-    skip_pc_metrics : bool
+    skip_pc_metrics : bool, default: False
         If True, PC metrics computation is skipped.
+    delete_existing_metrics : bool, default: False
+        If True, any quality metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept.
 
     Returns
     -------
@@ -49,7 +55,17 @@ class ComputeQualityMetrics(AnalyzerExtension):
     use_nodepipeline = False
     need_job_kwargs = True
 
-    def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False):
+    def _set_params(
+        self,
+        metric_names=None,
+        qm_params=None,
+        peak_sign=None,
+        seed=None,
+        skip_pc_metrics=False,
+        delete_existing_metrics=False,
+        metrics_to_compute=None,
+    ):
+
         if metric_names is None:
             metric_names = list(_misc_metric_name_to_func.keys())
             # if PC is available, PC metrics are automatically added to the list
@@ -71,12 +87,24 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No
             if "peak_sign" in qm_params_[k] and peak_sign is not None:
                 qm_params_[k]["peak_sign"] = peak_sign
 
+        metrics_to_compute = metric_names
+        qm_extension = self.sorting_analyzer.get_extension("quality_metrics")
+        if delete_existing_metrics is False and qm_extension is not None:
+
+            existing_metric_names = qm_extension.params["metric_names"]
+            existing_metric_names_propogated = [
+                metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute
+            ]
+            metric_names = metrics_to_compute + existing_metric_names_propogated
+
         params = dict(
-            metric_names=[str(name) for name in np.unique(metric_names)],
+            metric_names=metric_names,
             peak_sign=peak_sign,
             seed=seed,
             qm_params=qm_params_,
             skip_pc_metrics=skip_pc_metrics,
+            delete_existing_metrics=delete_existing_metrics,
+            metrics_to_compute=metrics_to_compute,
         )
 
         return params
@@ -91,6 +119,7 @@ def _merge_extension_data(
     ):
         import pandas as pd
 
+        metric_names = self.params["metric_names"]
         old_metrics = self.data["metrics"]
 
         all_unit_ids = new_sorting_analyzer.unit_ids
@@ -99,16 +128,18 @@ def _merge_extension_data(
         metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
         metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
-        metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
+        metrics.loc[new_unit_ids, :] = self._compute_metrics(
+            new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs
+        )
 
         new_data = dict(metrics=metrics)
         return new_data
 
-    def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs):
+    def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs):
         """
         Compute quality metrics.
         """
-        metric_names = self.params["metric_names"]
+
         qm_params = self.params["qm_params"]
         # sparsity = self.params["sparsity"]
         seed = self.params["seed"]
@@ -188,10 +219,35 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
 
         return metrics
 
     def _run(self, verbose=False, **job_kwargs):
-        self.data["metrics"] = self._compute_metrics(
-            sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, **job_kwargs
+
+        metrics_to_compute = self.params["metrics_to_compute"]
+        delete_existing_metrics = self.params["delete_existing_metrics"]
+
+        computed_metrics = self._compute_metrics(
+            sorting_analyzer=self.sorting_analyzer,
+            unit_ids=None,
+            verbose=verbose,
+            metric_names=metrics_to_compute,
+            **job_kwargs,
         )
 
+        existing_metrics = []
+        qm_extension = self.sorting_analyzer.get_extension("quality_metrics")
+        if (
+            delete_existing_metrics is False
+            and qm_extension is not None
+            and qm_extension.data.get("metrics") is not None
+        ):
+            existing_metrics = qm_extension.params["metric_names"]
+
+        # append the metrics which were previously computed
+        for metric_name in set(existing_metrics).difference(metrics_to_compute):
+            # some metric names produce data columns with other names. This deals with that.
+            for column_name in compute_name_to_column_names[metric_name]:
+                computed_metrics[column_name] = qm_extension.data["metrics"][column_name]
+
+        self.data["metrics"] = computed_metrics
+
     def _get_data(self):
         return self.data["metrics"]
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py
index 140ad87a8b..375dd320ae 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_list.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -53,3 +53,29 @@
     "drift": compute_drift_metrics,
     "sd_ratio": compute_sd_ratio,
 }
+
+# a dict mapping each metric computation name to the column names it produces in the output
+compute_name_to_column_names = {
+    "num_spikes": ["num_spikes"],
+    "firing_rate": ["firing_rate"],
+    "presence_ratio": ["presence_ratio"],
+    "snr": ["snr"],
+    "isi_violation": ["isi_violations_ratio", "isi_violations_count"],
+    "rp_violation": ["rp_violations", "rp_contamination"],
+    "sliding_rp_violation": ["sliding_rp_violation"],
+    "amplitude_cutoff": ["amplitude_cutoff"],
+    "amplitude_median": ["amplitude_median"],
+    "amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"],
+    "synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"],
+    "firing_range": ["firing_range"],
+    "drift": ["drift_ptp", "drift_std", "drift_mad"],
+    "sd_ratio": ["sd_ratio"],
+    "isolation_distance": ["isolation_distance"],
+    "l_ratio": ["l_ratio"],
+    "d_prime": ["d_prime"],
+    "nearest_neighbor": ["nn_hit_rate", "nn_miss_rate"],
+    "nn_isolation": ["nn_isolation", "nn_unit_id"],
+    "nn_noise_overlap": ["nn_noise_overlap"],
+    "silhouette": ["silhouette"],
+    "silhouette_full": ["silhouette_full"],
+}
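A self-contained sketch of how the new mapping is meant to be consumed (it mirrors the loop added to `_run` above); the DataFrames here are fabricated purely for illustration:

    import pandas as pd
    from spikeinterface.qualitymetrics.quality_metric_list import compute_name_to_column_names

    # pretend "isi_violation" was computed previously and only "snr" was requested now
    existing = pd.DataFrame({"isi_violations_ratio": [0.1], "isi_violations_count": [3]}, index=["unit0"])
    computed = pd.DataFrame({"snr": [5.2]}, index=["unit0"])

    # one computation name can produce several columns, so all of them are carried over
    for column_name in compute_name_to_column_names["isi_violation"]:
        computed[column_name] = existing[column_name]

    print(computed.columns.tolist())  # ['snr', 'isi_violations_ratio', 'isi_violations_count']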
+ """ + + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "firing_range": {"bin_size_s": 1}, + } + + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) + qm_extension = small_sorting_analyzer.get_extension("quality_metrics") + calculated_metrics = list(qm_extension.get_data().keys()) + + assert calculated_metrics == ["snr"] + + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}} + ) + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) + + quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + + # Check old metrics are not deleted and the new one is added to the data and metadata + assert set(list(quality_metric_extension.get_data().keys())) == set( + [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + ) + assert set(list(quality_metric_extension.params.get("metric_names"))) == set( + [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + ) + + # check that, when parameters are changed, the data and metadata are updated + old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": ["snr"], "qm_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + ) + new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + new_snr_data = new_quality_metric_extension.get_data()["snr"].values + + assert np.all(old_snr_data != new_snr_data) + assert new_quality_metric_extension.params["qm_params"]["snr"]["peak_mode"] == "peak_to_peak" + + # check that all quality metrics are deleted when parents are recomputed, even after + # recomputation + extensions_to_compute = { + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + small_sorting_analyzer.compute(extensions_to_compute) + + assert small_sorting_analyzer.get_extension("quality_metrics") is None + + +def test_metric_names_in_same_order(small_sorting_analyzer): + """ + Computes sepecified quality metrics and checks order is propogated. + """ + specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"] + small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names) + qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys() + for i in range(3): + assert specified_metric_names[i] == qm_keys[i] + + +def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): + """ + Computes quality metrics in binary folder format. Then computes subsets of quality + metrics and checks if they are saved correctly. + """ + + # can't use _misc_metric_name_to_func as some functions compute several qms + # e.g. 
+    quality_metrics = [
+        "num_spikes",
+        "firing_rate",
+        "presence_ratio",
+        "snr",
+        "isi_violations_ratio",
+        "isi_violations_count",
+        "rp_contamination",
+        "rp_violations",
+        "sliding_rp_violation",
+        "amplitude_cutoff",
+        "amplitude_median",
+        "amplitude_cv_median",
+        "amplitude_cv_range",
+        "sync_spike_2",
+        "sync_spike_4",
+        "sync_spike_8",
+        "firing_range",
+        "drift_ptp",
+        "drift_std",
+        "drift_mad",
+        "sd_ratio",
+        "isolation_distance",
+        "l_ratio",
+        "d_prime",
+        "silhouette",
+        "nn_hit_rate",
+        "nn_miss_rate",
+    ]
+
+    small_sorting_analyzer.compute("quality_metrics")
+
+    cache_folder = create_cache_folder
+    output_folder = cache_folder / "sorting_analyzer"
+
+    folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder)
+    quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv"
+
+    with open(quality_metrics_filename) as metrics_file:
+        saved_metrics = csv.reader(metrics_file)
+        metric_names = next(saved_metrics)
+
+    for metric_name in quality_metrics:
+        assert metric_name in metric_names
+
+    folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False)
+
+    with open(quality_metrics_filename) as metrics_file:
+        saved_metrics = csv.reader(metrics_file)
+        metric_names = next(saved_metrics)
+
+    for metric_name in quality_metrics:
+        assert metric_name in metric_names
+
+    folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True)
+
+    with open(quality_metrics_filename) as metrics_file:
+        saved_metrics = csv.reader(metrics_file)
+        metric_names = next(saved_metrics)
+
+    for metric_name in quality_metrics:
+        if metric_name == "snr":
+            assert metric_name in metric_names
+        else:
+            assert metric_name not in metric_names
+
+
 def test_unit_structure_in_output(small_sorting_analyzer):
 
     qm_params = {
@@ -129,10 +286,10 @@ def test_unit_id_order_independence(small_sorting_analyzer):
         small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params
     )
 
-    for metric, metric_1_data in quality_metrics_1.items():
-        assert quality_metrics_2[metric][2] == metric_1_data["#3"]
-        assert quality_metrics_2[metric][7] == metric_1_data["#9"]
-        assert quality_metrics_2[metric][1] == metric_1_data["#4"]
+    for metric, metric_2_data in quality_metrics_2.items():
+        assert quality_metrics_1[metric]["#3"] == metric_2_data[2]
+        assert quality_metrics_1[metric]["#9"] == metric_2_data[7]
+        assert quality_metrics_1[metric]["#4"] == metric_2_data[1]
 
 
 def _sorting_violation():
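Finally, an end-to-end sketch of the quality-metrics behaviour the new tests exercise (not part of the patch; the analyzer construction is illustrative and mirrors the `small_sorting_analyzer` fixture):

    import spikeinterface.full as si

    recording, sorting = si.generate_ground_truth_recording(durations=[2.0], num_units=5, seed=0)
    analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format="memory")
    analyzer.compute(["random_spikes", "noise_levels", "waveforms", "templates", "spike_amplitudes"])

    analyzer.compute("quality_metrics", metric_names=["snr", "firing_rate"])
    analyzer.compute("quality_metrics", metric_names=["amplitude_cutoff"])  # snr and firing_rate are kept
    analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True)  # only snr remains
    print(analyzer.get_extension("quality_metrics").get_data().columns)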