From dbbb1b80a69558e0f27097e8ec5df3dc820ae3d5 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Thu, 7 Nov 2024 14:42:36 +0000
Subject: [PATCH] Keep inconsistent params by metric info, and warn/error based on this

---
 .../curation/model_based_curation.py  | 56 +++++++++--------
 .../curation/train_manual_curation.py | 62 +++++++++----------
 2 files changed, 59 insertions(+), 59 deletions(-)

diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py
index 03fba1c14a..d3ed0ed0bd 100644
--- a/src/spikeinterface/curation/model_based_curation.py
+++ b/src/spikeinterface/curation/model_based_curation.py
@@ -44,7 +44,7 @@ def __init__(self, sorting_analyzer: SortingAnalyzer, pipeline):
         self.required_metrics = pipeline.feature_names_in_

     def predict_labels(
-        self, label_conversion=None, input_data=None, export_to_phy=False, model_info=None, enforce_params=False
+        self, label_conversion=None, input_data=None, export_to_phy=False, model_info=None, enforce_metric_params=False
     ):
         """
         Predicts the labels for the spike sorting data using the trained model.
@@ -61,7 +61,7 @@ def predict_labels(
             The input data for classification. If not provided, the method will extract metrics stored in the sorting analyzer.
         export_to_phy : bool, default: False.
            Whether to export the classified units to Phy format. Default is False.
-        enforce_params : bool, default: False
+        enforce_metric_params : bool, default: False
            If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parameters used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised.
@@ -81,7 +81,7 @@ def predict_labels(
                 raise ValueError("Input data must be a pandas DataFrame")

         if model_info is not None:
-            self._check_params_for_classification(enforce_params, model_info=model_info)
+            self._check_params_for_classification(enforce_metric_params, model_info=model_info)

         if model_info is not None and label_conversion is None:
             try:
@@ -151,13 +151,13 @@ def _get_metrics_for_classification(self):

         return input_data

-    def _check_params_for_classification(self, enforce_params=False, model_info=None):
+    def _check_params_for_classification(self, enforce_metric_params=False, model_info=None):
         """
         Check that quality and template metrics parameters match those used to train the model

         Parameters
         ----------
-        enforce_params : bool, default: False
+        enforce_metric_params : bool, default: False
            If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parameters used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised.
         model_info_path : str or Path, default: None

         if quality_metrics_extension is not None:

-            model_quality_metrics_params = model_info["metric_params"]["quality_metric_params"]["qm_params"]
+            model_quality_metrics_params = model_info["metric_params"]["quality_metric_params"]
             quality_metrics_params = quality_metrics_extension.params["qm_params"]

-            if model_quality_metrics_params == []:
-                warning_message = "Parameters used to compute quality metrics used to train this model are unknown."
-                if enforce_params is True:
-                    raise Exception(warning_message)
-                else:
-                    warnings.warn(warning_message)
-
-            # need to make sure both dicts are in json format, so that lists are equal
-            if json.dumps(quality_metrics_params) != json.dumps(model_quality_metrics_params):
-                warning_message = "Quality metrics params do not match those used to train model. Check these in the 'model_info.json' file."
-                if enforce_params is True:
+            inconsistent_metrics = []
+            for metric in model_quality_metrics_params["metric_names"]:
+                if metric not in model_quality_metrics_params["qm_params"]:
+                    inconsistent_metrics.append(metric)
+
+                elif quality_metrics_params[metric] != model_quality_metrics_params["qm_params"][metric]:
+                    warning_message = f"Quality metric params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file."
+                    if enforce_metric_params is True:
+                        raise Exception(warning_message)
+                    else:
+                        warnings.warn(warning_message)
+
+            if len(inconsistent_metrics) > 0:
+                warning_message = (
+                    f"Parameters used to compute the metrics {inconsistent_metrics} used to train this model are unknown."
+                )
+                if enforce_metric_params is True:
                     raise Exception(warning_message)
                 else:
                     warnings.warn(warning_message)
@@ -192,16 +198,16 @@ def _check_params_for_classification(self, enforce_params=False, model_info=None
             model_template_metrics_params = model_info["metric_params"]["template_metric_params"]["metrics_kwargs"]
             template_metrics_params = template_metrics_extension.params["metrics_kwargs"]

-            if template_metrics_params == []:
-                warning_message = "Parameters used to compute template metrics used to train this model are unknown."
-                if enforce_params is True:
+            if template_metrics_params == {}:
+                warning_message = "Parameters used to compute the template metrics used to train this model are unknown."
+                if enforce_metric_params is True:
                     raise Exception(warning_message)
                 else:
                     warnings.warn(warning_message)

             if template_metrics_params != model_template_metrics_params:
-                warning_message = "Template metrics metrics params do not match those used to train model. Check these in the 'model_info.json' file."
-                if enforce_params is True:
+                warning_message = "Template metrics params do not match those used to train the model. Parameters can be found in the 'model_info.json' file."
+                if enforce_metric_params is True:
                     raise Exception(warning_message)
                 else:
                     warnings.warn(warning_message)
@@ -234,7 +240,7 @@ def auto_label_units(
     trust_model=False,
     trusted=None,
     export_to_phy=False,
-    enforce_params=False,
+    enforce_metric_params=False,
 ):
     """
     Automatically labels units based on a model-based classification, either from a model
@@ -264,7 +270,7 @@ def auto_label_units(
         automatically inferred. If False, the `trusted` parameter must be provided to indicate the trusted objects.
     trusted : list of str, default: None
         Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file.
-    enforce_params : bool, default: False
+    enforce_metric_params : bool, default: False
        If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parameters used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised.
@@ -296,7 +302,7 @@ def auto_label_units(
         label_conversion=label_conversion,
         export_to_phy=export_to_phy,
         model_info=model_info,
-        enforce_params=enforce_params,
+        enforce_metric_params=enforce_metric_params,
     )

     return classified_units
diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py
index 2f420d8734..d4a73ccec1 100644
--- a/src/spikeinterface/curation/train_manual_curation.py
+++ b/src/spikeinterface/curation/train_manual_curation.py
@@ -229,41 +229,33 @@ def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params):
             warnings.warn("No metric_names provided, using all metrics calculated by the analyzers")
             self.metric_names = self.testing_metrics.columns.tolist()

-        consistent_params = self._check_metrics_parameters(analyzers, enforce_metric_params)
+        conflicting_metrics = self._check_metrics_parameters(analyzers, enforce_metric_params)

-        # Tidies up the metrics_params. If the metrics parameters are consistent, we keep one copy
         self.metrics_params = {}
-        if consistent_params is True:
-            if analyzers[0].has_extension("quality_metrics") is True:
-                self.metrics_params["quality_metric_params"] = (
-                    analyzers[0].extensions["quality_metrics"].params["qm_params"]
-                )
-            if analyzers[0].has_extension("template_metrics") is True:
-                self.metrics_params["template_metric_params"] = (
-                    analyzers[0].extensions["template_metrics"].params["qm_params"]
-                )
-        # If they are not, we only save the metric names
-        else:
-            metrics_params = {}
-            if analyzers[0].has_extension("quality_metrics") is True:
-                self.metrics_params["quality_metric_params"] = {}
-                self.metrics_params["quality_metric_params"]["metric_names"] = (
-                    analyzers[0].extensions["quality_metrics"].params["qm_params"]["metric_names"]
-                )
-            if analyzers[0].has_extension("template_metrics") is True:
-                self.metrics_params["template_metric_params"] = {}
-                self.metrics_params["template_metric_params"]["metric_names"] = (
-                    analyzers[0].extensions["template_metrics"].params["metric_names"]
-                )
-
-            self.metrics_params = metrics_params
+        if analyzers[0].has_extension("quality_metrics") is True:
+            self.metrics_params["quality_metric_params"] = analyzers[0].extensions["quality_metrics"].params.copy()
+            # remove metrics with conflicting params
+            if len(conflicting_metrics) > 0:
+                qm_names = self.metrics_params["quality_metric_params"]["metric_names"]
+                consistent_metrics = list(set(qm_names).difference(set(conflicting_metrics)))
+                consistent_metric_params = {
+                    metric: analyzers[0].extensions["quality_metrics"].params["qm_params"][metric]
+                    for metric in consistent_metrics
+                }
+                self.metrics_params["quality_metric_params"]["qm_params"] = consistent_metric_params
+
+        if analyzers[0].has_extension("template_metrics") is True:
+            self.metrics_params["template_metric_params"] = analyzers[0].extensions["template_metrics"].params.copy()
+            if "template_metrics" in conflicting_metrics:
+                self.metrics_params["template_metric_params"] = {}

         self.process_test_data_for_classification()

     def _check_metrics_parameters(self, analyzers, enforce_metric_params):
         """Checks that the metrics of each analyzer have been calculated using the same parameters"""

-        consistent_params = True
+        conflicting_metrics = []

         for analyzer_index_1, analyzer_1 in enumerate(analyzers):
             for analyzer_index_2, analyzer_2 in enumerate(analyzers):
@@ -279,25 +271,27 @@ def _check_metrics_parameters(self, analyzers, enforce_metric_params):
             if analyzer_2.has_extension("template_metrics") is True:
                 tm_params_2 = analyzer_2.extensions["template_metrics"].params["metrics_kwargs"]

-                conflicting_metrics = []
+                conflicting_metrics_between_1_2 = []
                 # check quality metrics params
                 for metric, params_1 in qm_params_1.items():
                     if params_1 != qm_params_2.get(metric):
-                        conflicting_metrics.append(metric)
+                        conflicting_metrics_between_1_2.append(metric)
                 # check template metric params
                 for metric, params_1 in tm_params_1.items():
                     if params_1 != tm_params_2.get(metric):
-                        conflicting_metrics.append("template_metrics")
+                        conflicting_metrics_between_1_2.append("template_metrics")
+
+                conflicting_metrics += conflicting_metrics_between_1_2

-                if len(conflicting_metrics) > 0:
-                    warning_message = f"Parameters used to calculate {conflicting_metrics} are different for sorting_analyzers #{analyzer_index_1} and #{analyzer_index_2}"
+                if len(conflicting_metrics_between_1_2) > 0:
+                    warning_message = f"Parameters used to calculate {conflicting_metrics_between_1_2} are different for sorting_analyzers #{analyzer_index_1} and #{analyzer_index_2}"
                     if enforce_metric_params is True:
                         raise Exception(warning_message)
                     else:
                         warnings.warn(warning_message)
-                    consistent_params = False

-        return consistent_params
+        unique_conflicting_metrics = set(conflicting_metrics)
+        return unique_conflicting_metrics

     def load_and_preprocess_csv(self, paths):
         self._load_data_files(paths)
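
For reference, the per-metric check that `_check_params_for_classification` now performs can be summarised as follows. This is a minimal, self-contained sketch with a hypothetical helper name and toy dictionaries (neither `check_metric_params` nor the example values are part of SpikeInterface): metrics listed in the model's `metric_names` but absent from its stored `qm_params` had inconsistent parameters at training time and are reported as unknown, while stored parameters that differ from the analyzer's trigger a warning, or an exception when `enforce_metric_params=True`.

import warnings

def check_metric_params(model_params, analyzer_qm_params, enforce_metric_params=False):
    # Hypothetical illustration of the logic added in this patch.
    unknown = []
    for metric in model_params["metric_names"]:
        if metric not in model_params["qm_params"]:
            # params were inconsistent across training analyzers, so they were not saved
            unknown.append(metric)
        elif analyzer_qm_params.get(metric) != model_params["qm_params"][metric]:
            msg = f"Quality metric params for {metric} do not match those used to train the model."
            if enforce_metric_params:
                raise Exception(msg)
            warnings.warn(msg)
    if unknown:
        msg = f"Parameters used to compute the metrics {unknown} used to train this model are unknown."
        if enforce_metric_params:
            raise Exception(msg)
        warnings.warn(msg)

# Toy example: 'snr' params match, 'isi_violation' differs, 'presence_ratio' is unknown.
model_params = {
    "metric_names": ["snr", "isi_violation", "presence_ratio"],
    "qm_params": {"snr": {"peak_sign": "neg"}, "isi_violation": {"isi_threshold_ms": 1.5}},
}
analyzer_qm_params = {
    "snr": {"peak_sign": "neg"},
    "isi_violation": {"isi_threshold_ms": 3.0},
    "presence_ratio": {"bin_duration_s": 60},
}
check_metric_params(model_params, analyzer_qm_params, enforce_metric_params=False)  # warns twice
# check_metric_params(model_params, analyzer_qm_params, enforce_metric_params=True)  # would raise

With `enforce_metric_params=True` the first mismatch raises, so `auto_label_units` and `predict_labels` only proceed when the analyzer's metric parameters match those recorded in the model's `model_info.json`.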