Skip to content

Commit

Permalink
respond to train_manual_curation review
Browse files — browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Nov 28, 2024
1 parent 65847c1 commit 4d0e259
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/curation/model_based_curation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -96,7 +96,7 @@ def predict_labels(
warnings.warn("Could not find `label_conversion` key in `model_info.json` file")

# Prepare input data
input_data = input_data.applymap(lambda x: np.nan if np.isinf(x) else x)
input_data = input_data.map(lambda x: np.nan if np.isinf(x) else x)
input_data = input_data.astype("float32")

# Apply classifier
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,8 +4,6 @@
from pathlib import Path

from spikeinterface.curation.tests.common import make_sorting_analyzer


from spikeinterface.curation.train_manual_curation import CurationModelTrainer, train_model


Expand Down
47 changes: 25 additions & 22 deletions src/spikeinterface/curation/train_manual_curation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,6 +7,7 @@
from spikeinterface.qualitymetrics import get_quality_metric_list, get_quality_pca_metric_list
from spikeinterface.postprocessing import get_template_metric_names
from pathlib import Path
from copy import deepcopy

default_classifier_search_spaces = {
"RandomForestClassifier": {
Expand Down Expand Up @@ -157,14 +158,19 @@ def __init__(
search_kwargs=None,
**job_kwargs,
):

import pandas as pd

if imputation_strategies is None:
imputation_strategies = ["median", "most_frequent", "knn", "iterative"]

if scaling_techniques is None:
scaling_techniques = [
"standard_scaler",
"min_max_scaler",
"robust_scaler",
]

if classifiers is None:
self.classifiers = ["RandomForestClassifier"]
self.classifier_search_space = None
Expand All @@ -177,6 +183,10 @@ def __init__(
else:
raise ValueError("classifiers must be a list or dictionary")

# check if labels is a list of lists
if not all(isinstance(labels, list) for labels in labels):
raise ValueError("labels must be a list of lists")

self.folder = Path(folder) if folder is not None else None
self.imputation_strategies = imputation_strategies
self.scaling_techniques = scaling_techniques
Expand All @@ -193,12 +203,6 @@ def __init__(

self.requirements = {"spikeinterface": spikeinterface.__version__}

import pandas as pd

# check if labels is a list of lists
if not all(isinstance(labels, list) for labels in labels):
raise ValueError("labels must be a list of lists")

self.y = pd.concat([pd.DataFrame(one_labels)[0] for one_labels in labels])

self.metric_names = metric_names
Expand Down Expand Up @@ -246,9 +250,9 @@ def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params):
self.metrics_params["quality_metric_params"]["qm_params"] = consistent_metric_params

if analyzers[0].has_extension("template_metrics") is True:
self.metrics_params["template_metric_params"] = analyzers[0].extensions["template_metrics"].params
self.metrics_params["template_metric_params"] = deepcopy(analyzers[0].extensions["template_metrics"].params)
if "template_metrics" in conflicting_metrics:
self.metrics_params["template_metric_params"] = analyzers[0].extensions["template_metrics"].params = {}
self.metrics_params["template_metric_params"] = {}

self.process_test_data_for_classification()

Expand Down Expand Up @@ -305,7 +309,7 @@ def load_and_preprocess_csv(self, paths):

def process_test_data_for_classification(self):
"""
Processes the test data for classification.
Cleans the input data so that it can be used by sklearn.
Extracts the target variable and features from the loaded dataset.
It handles string labels by converting them to integer codes and reindexes the
Expand Down Expand Up @@ -337,14 +341,13 @@ def process_test_data_for_classification(self):
)
self.X = self.testing_metrics[self.metric_names]
except KeyError as e:
print("metrics_list contains invalid metric names")
raise e
raise KeyError(f"{str(e)}, metrics_list contains invalid metric names")

self.X = self.testing_metrics.reindex(columns=self.metric_names)
self.X = self.X.map(lambda x: np.nan if np.isinf(x) else x)
self.X = self.X.astype("float32")
self.X.fillna(0, inplace=True)

def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_train, X_val, y_train, y_val):
def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test):
"""Impute and scale the data using the specified techniques."""
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import SimpleImputer, KNNImputer, IterativeImputer
Expand All @@ -371,12 +374,13 @@ def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_tra
f"Unknown scaling technique: {scaling_technique}. Supported scaling techniques are 'standard_scaler', 'min_max_scaler' and 'robust_scaler."
)

y_train = y_train.astype(int)
y_val = y_val.astype(int)
y_train_processed = y_train.astype(int)
y_test = y_test.astype(int)

X_train_imputed = imputer.fit_transform(X_train)
X_val_imputed = imputer.transform(X_val)
X_train_scaled = scaler.fit_transform(X_train_imputed)
X_val_scaled = scaler.transform(X_val_imputed)
X_test_imputed = imputer.transform(X_test)
X_train_processed = scaler.fit_transform(X_train_imputed)
X_test_processed = scaler.transform(X_test_imputed)

# Apply SMOTE for class imbalance
if self.smote:
Expand All @@ -385,9 +389,9 @@ def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_tra
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install imbalanced-learn package to use SMOTE")
smote = SMOTE(random_state=self.seed)
X_train_scaled, y_train = smote.fit_resample(X_train_scaled, y_train)
X_train_processed, y_train_processed = smote.fit_resample(X_train_processed, y_train_processed)

return X_train_scaled, X_val_scaled, y_train, y_val, imputer, scaler
return X_train_processed, X_test_processed, y_train_processed, y_test, imputer, scaler

def get_classifier_instance(self, classifier_name):
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
Expand Down Expand Up @@ -786,8 +790,7 @@ def check_metric_names_are_the_same(metrics_for_each_analyzer):
if metric_names_1 != metric_names_2:
metrics_in_1_but_not_2 = metric_names_1.difference(metric_names_2)
metrics_in_2_but_not_1 = metric_names_2.difference(metric_names_1)
print(metrics_in_1_but_not_2)
print(metrics_in_2_but_not_1)

error_message = f"Computed metrics are not equal for sorting_analyzers #{j} and #{i}\n"
if metrics_in_1_but_not_2:
error_message += f"#{j} does not contain {metrics_in_1_but_not_2}, which #{i} does."
Expand Down

0 comments on commit 4d0e259

Please sign in to comment.