Skip to content

Commit

Permalink
Write metric names for csv mode, and don't check params if not in mod…
Browse files Browse the repository at this point in the history
…el_info
  • Loading branch information
chrishalcrow committed Dec 11, 2024
1 parent 52ac6a0 commit fee5f5b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/spikeinterface/curation/model_based_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def predict_labels(
probabilities = np.max(probabilities, axis=1)

if isinstance(label_conversion, dict):
if set(predictions) != set(label_conversion.keys()):

if set(predictions).issubset(set(label_conversion.keys())) is False:
raise ValueError("Labels in predictions do not match those in label_conversion")
predictions = [label_conversion[label] for label in predictions]

Expand Down Expand Up @@ -161,9 +162,10 @@ def _check_params_for_classification(self, enforce_metric_params=False, model_in

# remove the 's' at the end of the extension name
extension_name = extension_name[:-1]
if metric_extension is not None:
model_metric_params = model_info["metric_params"].get(extension_name + "_params")

if metric_extension is not None and model_metric_params is not None:

model_metric_params = model_info["metric_params"][extension_name + "_params"]
metric_params = metric_extension.params["metric_params"]

inconsistent_metrics = []
Expand Down
3 changes: 3 additions & 0 deletions src/spikeinterface/curation/train_manual_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ def _check_metrics_parameters(self, analyzers, enforce_metric_params):
def load_and_preprocess_csv(self, paths):
self._load_data_files(paths)
self.process_test_data_for_classification()
self.metrics_params = {}
for metric_name in self.metric_names:
self.metrics_params[metric_name] = {}

def process_test_data_for_classification(self):
"""
Expand Down

0 comments on commit fee5f5b

Please sign in to comment.