Keep inconsistent params by metric info, and warn/error based on this
chrishalcrow committed Nov 7, 2024
1 parent 9488416 commit dbbb1b8
Showing 2 changed files with 59 additions and 59 deletions.
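This commit renames `enforce_params` to `enforce_metric_params` and makes the parameter check operate per metric: metrics whose training-time parameters are unknown are collected and reported together, while known parameters that differ trigger a warning or, if enforcement is on, an error. As a rough usage sketch (only `enforce_metric_params` comes from this commit; the analyzer setup and the model-loading keyword are placeholders for illustration):

```python
from spikeinterface.curation import auto_label_units

# `analyzer` is assumed to be a SortingAnalyzer with the "quality_metrics" and
# "template_metrics" extensions already computed; "a_user/a_model" is a
# hypothetical Hugging Face repo id, used only for illustration.
classified_units = auto_label_units(
    sorting_analyzer=analyzer,
    repo_id="a_user/a_model",
    enforce_metric_params=True,  # raise instead of warn when metric params differ from the model's
)
```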
56 changes: 31 additions & 25 deletions src/spikeinterface/curation/model_based_curation.py
@@ -44,7 +44,7 @@ def __init__(self, sorting_analyzer: SortingAnalyzer, pipeline):
self.required_metrics = pipeline.feature_names_in_

def predict_labels(
self, label_conversion=None, input_data=None, export_to_phy=False, model_info=None, enforce_params=False
self, label_conversion=None, input_data=None, export_to_phy=False, model_info=None, enforce_metric_params=False
):
"""
Predicts the labels for the spike sorting data using the trained model.
@@ -61,7 +61,7 @@ def predict_labels(
The input data for classification. If not provided, the method will extract metrics stored in the sorting analyzer.
export_to_phy : bool, default: False.
Whether to export the classified units to Phy format. Default is False.
enforce_params : bool, default: False
enforce_metric_params : bool, default: False
If True and the parameters used to compute the metrics in `sorting_analyzer` are different from the parameters
used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised.
@@ -81,7 +81,7 @@ def predict_labels(
raise ValueError("Input data must be a pandas DataFrame")

if model_info is not None:
self._check_params_for_classification(enforce_params, model_info=model_info)
self._check_params_for_classification(enforce_metric_params, model_info=model_info)

if model_info is not None and label_conversion is None:
try:
@@ -151,13 +151,13 @@ def _get_metrics_for_classification(self):

return input_data

def _check_params_for_classification(self, enforce_params=False, model_info=None):
def _check_params_for_classification(self, enforce_metric_params=False, model_info=None):
"""
Check that quality and template metrics parameters match those used to train the model
Parameters
----------
enforce_params : bool, default: False
enforce_metric_params : bool, default: False
If True and the parameters used to compute the metrics in `sorting_analyzer` are different from the parameters
used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised.
model_info : dict, default: None
@@ -169,20 +169,26 @@ def _check_params_for_classification(self, enforce_params=False, model_info=None

if quality_metrics_extension is not None:

model_quality_metrics_params = model_info["metric_params"]["quality_metric_params"]["qm_params"]
model_quality_metrics_params = model_info["metric_params"]["quality_metric_params"]
quality_metrics_params = quality_metrics_extension.params["qm_params"]

if model_quality_metrics_params == []:
warning_message = "Parameters used to compute quality metrics used to train this model are unknown."
if enforce_params is True:
raise Exception(warning_message)
else:
warnings.warn(warning_message)

# need to make sure both dicts are in json format, so that lists are equal
if json.dumps(quality_metrics_params) != json.dumps(model_quality_metrics_params):
warning_message = "Quality metrics params do not match those used to train model. Check these in the 'model_info.json' file."
if enforce_params is True:
inconsistent_metrics = []
for metric in model_quality_metrics_params["metric_names"]:
if metric not in model_quality_metrics_params["qm_params"]:
inconsistent_metrics.append(metric)

elif quality_metrics_params[metric] != model_quality_metrics_params["qm_params"][metric]:
warning_message = "Quality metric params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file."
if enforce_metric_params is True:
raise Exception(warning_message)
else:
warnings.warn(warning_message)

if len(inconsistent_metrics) > 0:
warning_message = (
"Parameters used to compute metrics {inconsistent_metrics}, used to train this model, are unknown."
)
if enforce_metric_params is True:
raise Exception(warning_message)
else:
warnings.warn(warning_message)
@@ -192,16 +198,16 @@ def _check_params_for_classification(self, enforce_params=False, model_info=None
model_template_metrics_params = model_info["metric_params"]["template_metric_params"]["metrics_kwargs"]
template_metrics_params = template_metrics_extension.params["metrics_kwargs"]

if template_metrics_params == []:
warning_message = "Parameters used to compute template metrics used to train this model are unknown."
if enforce_params is True:
if template_metrics_params == {}:
warning_message = "Parameters used to compute template metrics, used to train this model, are unknown."
if enforce_metric_params is True:
raise Exception(warning_message)
else:
warnings.warn(warning_message)

if template_metrics_params != model_template_metrics_params:
warning_message = "Template metrics metrics params do not match those used to train model. Check these in the 'model_info.json' file."
if enforce_params is True:
warning_message = "Template metrics params do not match those used to train model. Parameters can be found in the 'model_info.json' file."
if enforce_metric_params is True:
raise Exception(warning_message)
else:
warnings.warn(warning_message)
@@ -234,7 +240,7 @@ def auto_label_units(
trust_model=False,
trusted=None,
export_to_phy=False,
enforce_params=False,
enforce_metric_params=False,
):
"""
Automatically labels units based on a model-based classification, either from a model
@@ -264,7 +270,7 @@ def auto_label_units(
automatically inferred. If False, the `trusted` parameter must be provided to indicate the trusted objects.
trusted : list of str, default: None
Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file.
enforce_params : bool, default: False
enforce_metric_params : bool, default: False
If True and the parameters used to compute the metrics in `sorting_analyzer` are different from the parameters
used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised.
@@ -296,7 +302,7 @@ def auto_label_units(
label_conversion=label_conversion,
export_to_phy=export_to_phy,
model_info=model_info,
enforce_params=enforce_params,
enforce_metric_params=enforce_metric_params,
)

return classified_units
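To make the new per-metric check in `_check_params_for_classification` concrete, here is a standalone sketch that performs the same comparison on plain dictionaries shaped like the `model_info.json` metadata. The function name and example values are illustrative only; the library method works on the sorting analyzer's extensions rather than bare dicts:

```python
import warnings


def check_quality_metric_params(analyzer_qm_params, model_qm_info, enforce_metric_params=False):
    """Warn (or raise) when per-metric parameters disagree with the model's training parameters."""
    inconsistent_metrics = []
    for metric in model_qm_info["metric_names"]:
        if metric not in model_qm_info["qm_params"]:
            # Training dropped this metric's params because they were inconsistent
            # across the training analyzers, so they are unknown at prediction time.
            inconsistent_metrics.append(metric)
        elif analyzer_qm_params[metric] != model_qm_info["qm_params"][metric]:
            message = f"Quality metric params for {metric} do not match those used to train the model."
            if enforce_metric_params:
                raise Exception(message)
            warnings.warn(message)
    if inconsistent_metrics:
        message = f"Parameters used to compute metrics {inconsistent_metrics}, used to train this model, are unknown."
        if enforce_metric_params:
            raise Exception(message)
        warnings.warn(message)


# 'snr' matches, 'isi_violation' differs (warning), 'presence_ratio' has no stored params (warning).
model_qm_info = {
    "metric_names": ["snr", "isi_violation", "presence_ratio"],
    "qm_params": {"snr": {"peak_sign": "neg"}, "isi_violation": {"isi_threshold_ms": 1.5}},
}
analyzer_qm_params = {"snr": {"peak_sign": "neg"}, "isi_violation": {"isi_threshold_ms": 3.0}}
check_quality_metric_params(analyzer_qm_params, model_qm_info)
```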
62 changes: 28 additions & 34 deletions src/spikeinterface/curation/train_manual_curation.py
@@ -229,41 +229,33 @@ def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params):
warnings.warn("No metric_names provided, using all metrics calculated by the analyzers")
self.metric_names = self.testing_metrics.columns.tolist()

consistent_params = self._check_metrics_parameters(analyzers, enforce_metric_params)
conflicting_metrics = self._check_metrics_parameters(analyzers, enforce_metric_params)
print(conflicting_metrics)

# Tidies up the metrics_params. If the metrics parameters are consistent, we keep one copy
self.metrics_params = {}
if consistent_params is True:
if analyzers[0].has_extension("quality_metrics") is True:
self.metrics_params["quality_metric_params"] = (
analyzers[0].extensions["quality_metrics"].params["qm_params"]
)
if analyzers[0].has_extension("template_metrics") is True:
self.metrics_params["template_metric_params"] = (
analyzers[0].extensions["template_metrics"].params["qm_params"]
)
# If they are not, we only save the metric names
else:
metrics_params = {}
if analyzers[0].has_extension("quality_metrics") is True:
self.metrics_params["quality_metric_params"] = {}
self.metrics_params["quality_metric_params"]["metric_names"] = (
analyzers[0].extensions["quality_metrics"].params["qm_params"]["metric_names"]
)
if analyzers[0].has_extension("template_metrics") is True:
self.metrics_params["template_metric_params"] = {}
self.metrics_params["template_metric_params"]["metric_names"] = (
analyzers[0].extensions["template_metrics"].params["metric_names"]
)

self.metrics_params = metrics_params
if analyzers[0].has_extension("quality_metrics") is True:
self.metrics_params["quality_metric_params"] = analyzers[0].extensions["quality_metrics"].params
# remove metrics with conflicting params
if len(conflicting_metrics) > 0:
qm_names = self.metrics_params["quality_metric_params"]["metric_names"]
consistent_metrics = list(set(qm_names).difference(set(conflicting_metrics)))
consistent_metric_params = {
metric: analyzers[0].extensions["quality_metrics"].params["qm_params"][metric]
for metric in consistent_metrics
}
self.metrics_params["quality_metric_params"]["qm_params"] = consistent_metric_params

if analyzers[0].has_extension("template_metrics") is True:
self.metrics_params["template_metric_params"] = analyzers[0].extensions["template_metrics"].params
if "template_metrics" in conflicting_metrics:
self.metrics_params["template_metric_params"] = analyzers[0].extensions["template_metrics"].params = {}

self.process_test_data_for_classification()

def _check_metrics_parameters(self, analyzers, enforce_metric_params):
"""Checks that the metrics of each analyzer have been calcualted using the same parameters"""

consistent_params = True
conflicting_metrics = []
for analyzer_index_1, analyzer_1 in enumerate(analyzers):
for analyzer_index_2, analyzer_2 in enumerate(analyzers):

@@ -279,25 +271,27 @@ def _check_metrics_parameters(self, analyzers, enforce_metric_params):
if analyzer_2.has_extension("template_metrics") is True:
tm_params_2 = analyzer_2.extensions["template_metrics"].params["metrics_kwargs"]

conflicting_metrics = []
conflicting_metrics_between_1_2 = []
# check quality metrics params
for metric, params_1 in qm_params_1.items():
if params_1 != qm_params_2.get(metric):
conflicting_metrics.append(metric)
conflicting_metrics_between_1_2.append(metric)
# check template metric params
for metric, params_1 in tm_params_1.items():
if params_1 != tm_params_2.get(metric):
conflicting_metrics.append("template_metrics")
conflicting_metrics_between_1_2.append("template_metrics")

conflicting_metrics += conflicting_metrics_between_1_2

if len(conflicting_metrics) > 0:
warning_message = f"Parameters used to calculate {conflicting_metrics} are different for sorting_analyzers #{analyzer_index_1} and #{analyzer_index_2}"
if len(conflicting_metrics_between_1_2) > 0:
warning_message = f"Parameters used to calculate {conflicting_metrics_between_1_2} are different for sorting_analyzers #{analyzer_index_1} and #{analyzer_index_2}"
if enforce_metric_params is True:
raise Exception(warning_message)
else:
warnings.warn(warning_message)
consistent_params = False

return consistent_params
unique_conflicting_metrics = set(conflicting_metrics)
return unique_conflicting_metrics

def load_and_preprocess_csv(self, paths):
self._load_data_files(paths)
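The training-side change can be summarised the same way: `_check_metrics_parameters` now returns the set of metric names whose parameters conflict between any pair of analyzers, and only the remaining, consistent parameters are kept in the saved model metadata. A simplified sketch using plain per-analyzer parameter dicts (the function name and inputs are illustrative, not the library API):

```python
import warnings


def find_conflicting_metrics(per_analyzer_qm_params, enforce_metric_params=False):
    """Return the set of metric names whose params differ between any two analyzers."""
    conflicting_metrics = []
    for i, params_1 in enumerate(per_analyzer_qm_params):
        for j, params_2 in enumerate(per_analyzer_qm_params):
            conflicts_between_i_j = [m for m, p in params_1.items() if p != params_2.get(m)]
            if conflicts_between_i_j:
                message = f"Parameters used to calculate {conflicts_between_i_j} are different for sorting_analyzers #{i} and #{j}"
                if enforce_metric_params:
                    raise Exception(message)
                warnings.warn(message)
            conflicting_metrics += conflicts_between_i_j
    return set(conflicting_metrics)


# 'snr' params agree across analyzers; 'amplitude_cutoff' params do not, so only
# 'snr' parameters would be kept with the trained model.
qm_params_per_analyzer = [
    {"snr": {"peak_sign": "neg"}, "amplitude_cutoff": {"num_histogram_bins": 100}},
    {"snr": {"peak_sign": "neg"}, "amplitude_cutoff": {"num_histogram_bins": 500}},
]
conflicting = find_conflicting_metrics(qm_params_per_analyzer)  # {'amplitude_cutoff'}
consistent_params = {m: p for m, p in qm_params_per_analyzer[0].items() if m not in conflicting}
```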
