From 4d0e259658e485a398d59c066cca02aacb37d9c0 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Thu, 28 Nov 2024 11:51:25 +0000
Subject: [PATCH] respond to train_manual_curation review

---
 .../curation/model_based_curation.py          |  2 +-
 .../tests/test_train_manual_curation.py       |  2 --
 .../curation/train_manual_curation.py         | 47 ++++++++++---------
 3 files changed, 26 insertions(+), 25 deletions(-)

diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py
index 66fdd6118e..bdd340459c 100644
--- a/src/spikeinterface/curation/model_based_curation.py
+++ b/src/spikeinterface/curation/model_based_curation.py
@@ -96,7 +96,7 @@ def predict_labels(
             warnings.warn("Could not find `label_conversion` key in `model_info.json` file")
 
         # Prepare input data
-        input_data = input_data.applymap(lambda x: np.nan if np.isinf(x) else x)
+        input_data = input_data.map(lambda x: np.nan if np.isinf(x) else x)
         input_data = input_data.astype("float32")
 
         # Apply classifier
diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py
index 759e560329..59b8565200 100644
--- a/src/spikeinterface/curation/tests/test_train_manual_curation.py
+++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py
@@ -4,8 +4,6 @@
 from pathlib import Path
 
 from spikeinterface.curation.tests.common import make_sorting_analyzer
-
-
 from spikeinterface.curation.train_manual_curation import CurationModelTrainer, train_model
 
 
diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py
index f1539d2e3a..4b8399bc36 100644
--- a/src/spikeinterface/curation/train_manual_curation.py
+++ b/src/spikeinterface/curation/train_manual_curation.py
@@ -7,6 +7,7 @@
 from spikeinterface.qualitymetrics import get_quality_metric_list, get_quality_pca_metric_list
 from spikeinterface.postprocessing import get_template_metric_names
 from pathlib import Path
+from copy import deepcopy
 
 default_classifier_search_spaces = {
     "RandomForestClassifier": {
@@ -157,14 +158,19 @@ def __init__(
         search_kwargs=None,
         **job_kwargs,
     ):
+
+        import pandas as pd
+
         if imputation_strategies is None:
             imputation_strategies = ["median", "most_frequent", "knn", "iterative"]
+
         if scaling_techniques is None:
             scaling_techniques = [
                 "standard_scaler",
                 "min_max_scaler",
                 "robust_scaler",
             ]
+
         if classifiers is None:
             self.classifiers = ["RandomForestClassifier"]
             self.classifier_search_space = None
@@ -177,6 +183,10 @@ def __init__(
         else:
             raise ValueError("classifiers must be a list or dictionary")
 
+        # check if labels is a list of lists
+        if not all(isinstance(labels, list) for labels in labels):
+            raise ValueError("labels must be a list of lists")
+
         self.folder = Path(folder) if folder is not None else None
         self.imputation_strategies = imputation_strategies
         self.scaling_techniques = scaling_techniques
@@ -193,12 +203,6 @@ def __init__(
 
         self.requirements = {"spikeinterface": spikeinterface.__version__}
 
-        import pandas as pd
-
-        # check if labels is a list of lists
-        if not all(isinstance(labels, list) for labels in labels):
-            raise ValueError("labels must be a list of lists")
-
         self.y = pd.concat([pd.DataFrame(one_labels)[0] for one_labels in labels])
 
         self.metric_names = metric_names
@@ -246,9 +250,9 @@ def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params):
                 self.metrics_params["quality_metric_params"]["qm_params"] = consistent_metric_params
 
         if analyzers[0].has_extension("template_metrics") is True:
-            self.metrics_params["template_metric_params"] = analyzers[0].extensions["template_metrics"].params
+            self.metrics_params["template_metric_params"] = deepcopy(analyzers[0].extensions["template_metrics"].params)
             if "template_metrics" in conflicting_metrics:
-                self.metrics_params["template_metric_params"] = analyzers[0].extensions["template_metrics"].params = {}
+                self.metrics_params["template_metric_params"] = {}
 
         self.process_test_data_for_classification()
 
@@ -305,7 +309,7 @@ def load_and_preprocess_csv(self, paths):
 
     def process_test_data_for_classification(self):
         """
-        Processes the test data for classification.
+        Cleans the input data so that it can be used by sklearn.
 
         Extracts the target variable and features from the loaded dataset.
         It handles string labels by converting them to integer codes and reindexes the
@@ -337,14 +341,13 @@ def process_test_data_for_classification(self):
             )
             self.X = self.testing_metrics[self.metric_names]
         except KeyError as e:
-            print("metrics_list contains invalid metric names")
-            raise e
+            raise KeyError(f"{str(e)}, metrics_list contains invalid metric names")
+
         self.X = self.testing_metrics.reindex(columns=self.metric_names)
         self.X = self.X.map(lambda x: np.nan if np.isinf(x) else x)
         self.X = self.X.astype("float32")
-        self.X.fillna(0, inplace=True)
 
-    def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_train, X_val, y_train, y_val):
+    def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test):
         """Impute and scale the data using the specified techniques."""
         from sklearn.experimental import enable_iterative_imputer
         from sklearn.impute import SimpleImputer, KNNImputer, IterativeImputer
@@ -371,12 +374,13 @@ def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_tra
                 f"Unknown scaling technique: {scaling_technique}. Supported scaling techniques are 'standard_scaler', 'min_max_scaler' and 'robust_scaler."
             )
 
-        y_train = y_train.astype(int)
-        y_val = y_val.astype(int)
+        y_train_processed = y_train.astype(int)
+        y_test = y_test.astype(int)
+
         X_train_imputed = imputer.fit_transform(X_train)
-        X_val_imputed = imputer.transform(X_val)
-        X_train_scaled = scaler.fit_transform(X_train_imputed)
-        X_val_scaled = scaler.transform(X_val_imputed)
+        X_test_imputed = imputer.transform(X_test)
+        X_train_processed = scaler.fit_transform(X_train_imputed)
+        X_test_processed = scaler.transform(X_test_imputed)
 
         # Apply SMOTE for class imbalance
         if self.smote:
@@ -385,9 +389,9 @@ def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_tra
             except ModuleNotFoundError:
                 raise ModuleNotFoundError("Please install imbalanced-learn package to use SMOTE")
             smote = SMOTE(random_state=self.seed)
-            X_train_scaled, y_train = smote.fit_resample(X_train_scaled, y_train)
+            X_train_processed, y_train_processed = smote.fit_resample(X_train_processed, y_train_processed)
 
-        return X_train_scaled, X_val_scaled, y_train, y_val, imputer, scaler
+        return X_train_processed, X_test_processed, y_train_processed, y_test, imputer, scaler
 
     def get_classifier_instance(self, classifier_name):
         from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
@@ -786,8 +790,7 @@ def check_metric_names_are_the_same(metrics_for_each_analyzer):
             if metric_names_1 != metric_names_2:
                 metrics_in_1_but_not_2 = metric_names_1.difference(metric_names_2)
                 metrics_in_2_but_not_1 = metric_names_2.difference(metric_names_1)
-                print(metrics_in_1_but_not_2)
-                print(metrics_in_2_but_not_1)
+
                 error_message = f"Computed metrics are not equal for sorting_analyzers #{j} and #{i}\n"
                 if metrics_in_1_but_not_2:
                     error_message += f"#{j} does not contain {metrics_in_1_but_not_2}, which #{i} does."
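
Illustration (not part of the patch): the two recurring ideas above are the move from the deprecated pandas `DataFrame.applymap` to the elementwise `DataFrame.map` (available since pandas 2.1), and wrapping the analyzer's extension params in `deepcopy` so the trainer's stored copy no longer mutates the analyzer's own state. A minimal sketch of both, using made-up metric values and a hypothetical `params` dict standing in for the extension params:

    import numpy as np
    import pandas as pd
    from copy import deepcopy

    # Hypothetical metrics table with an infinite entry, as predict_labels might see
    metrics = pd.DataFrame({"snr": [5.2, np.inf], "firing_rate": [0.1, 3.0]})

    # Elementwise inf -> NaN cleaning, the pattern used in the patch
    cleaned = metrics.map(lambda x: np.nan if np.isinf(x) else x).astype("float32")
    print(cleaned)

    # deepcopy keeps the stored copy independent of the source dict
    params = {"metric_names": ["snr"]}  # stand-in for extension params
    stored = deepcopy(params)
    stored["metric_names"] = []
    assert params["metric_names"] == ["snr"]  # original left untouched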