
Commit

Change metric_path to metric_pathS
chrishalcrow committed Oct 11, 2024
1 parent e7904e1 commit 2c9a39b
Showing 3 changed files with 36 additions and 13 deletions.
doc/how_to/auto_curation_training.rst (4 changes: 2 additions & 2 deletions)
@@ -45,14 +45,14 @@ We can also access the model, which is an sklearn ``Pipeline``, from the trainer
 The training function can also be run in “csv” mode, if you prefer to
-store metrics in a single .csv file. If the target labels are stored as a column in
+store metrics in .csv files. If the target labels are stored as a column in
 the file, you can point to these with the ``target_label`` parameter

 .. code::

     trainer = train_model(
         mode="csv",
-        metrics_path = "/path/to/csv",
+        metrics_paths = ["/path/to/csv_file_1", "/path/to/csv_file_2"],
         target_label = "my_label",
         output_folder=output_folder,
     )
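
Not part of the commit, but for orientation: a minimal sketch of calling the updated csv mode with several metrics files collected via ``pathlib``. The directory, file names, and label values are placeholders, the ``train_model`` import location is assumed, and the keyword arguments follow the new test added in this commit rather than the ``target_label``/``output_folder`` form in the docs snippet above.

    from pathlib import Path

    from spikeinterface.curation import train_model  # import location assumed

    # Hypothetical layout: one quality-metrics CSV per recording session.
    metrics_dir = Path("/path/to/metrics")
    metrics_paths = sorted(str(p) for p in metrics_dir.glob("*.csv"))

    # Placeholder curated labels: one list per CSV, in the same order as metrics_paths.
    labels = [[0, 1, 0, 1], [1, 0, 1, 0]]

    trainer = train_model(
        mode="csv",
        metrics_paths=metrics_paths,  # now a list of CSV paths rather than a single path
        folder="/path/to/output_folder",
        labels=labels,
        metric_names=["metric1", "metric2", "metric3"],
    )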
src/spikeinterface/curation/tests/test_train_manual_curation.py (26 changes: 24 additions & 2 deletions)
@@ -39,7 +39,7 @@ def test_load_and_preprocess_full(trainer):
     temp_file_path = make_temp_training_csv()

     # Load and preprocess the data from the temporary CSV file
-    trainer.load_and_preprocess_csv(temp_file_path)
+    trainer.load_and_preprocess_csv([temp_file_path])

     # Assert that the data is loaded and preprocessed correctly
     assert trainer.X is not None
@@ -102,14 +102,36 @@ def test_evaluate_model_config(trainer):
     assert (trainer_folder / "model_info.json").is_file()


+def test_train_model_using_two_csvs():
+
+    metrics_path_1 = make_temp_training_csv()
+    metrics_path_2 = make_temp_training_csv()
+
+    folder = tempfile.mkdtemp()
+    metric_names = ["metric1", "metric2", "metric3"]
+
+    trainer = train_model(
+        mode="csv",
+        metrics_paths=[metrics_path_1, metrics_path_2],
+        folder=folder,
+        labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]],
+        metric_names=metric_names,
+        imputation_strategies=["median"],
+        scaling_techniques=["standard_scaler"],
+        classifiers=["LogisticRegression"],
+        overwrite=True,
+    )
+    assert isinstance(trainer, CurationModelTrainer)
+
+
 def test_train_model():

     metrics_path = make_temp_training_csv()
     folder = tempfile.mkdtemp()
     metric_names = ["metric1", "metric2", "metric3"]
     trainer = train_model(
         mode="csv",
-        metrics_path=metrics_path,
+        metrics_paths=[metrics_path],
         folder=folder,
         labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]],
         metric_names=metric_names,
src/spikeinterface/curation/train_manual_curation.py (19 changes: 10 additions & 9 deletions)
@@ -232,8 +232,8 @@ def _check_metrics_parameters(self):
"same parameters for each sorting_analyzer."
)

def load_and_preprocess_csv(self, path):
self._load_data_file(path)
def load_and_preprocess_csv(self, paths):
self._load_data_files(paths)
self.process_test_data_for_classification()

def process_test_data_for_classification(self):
@@ -443,10 +443,10 @@ def _get_metrics_for_classification(self, analyzer, analyzer_index):

         return calculated_metrics

-    def _load_data_file(self, path):
+    def _load_data_files(self, paths):
         import pandas as pd

-        self.testing_metrics = pd.read_csv(path, index_col=0)
+        self.testing_metrics = pd.concat([pd.read_csv(path, index_col=0) for path in paths], axis=0)

     def _evaluate(self, imputation_strategies, scaling_techniques, classifiers, X_train, X_test, y_train, y_test):
         from joblib import Parallel, delayed
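
A standalone pandas illustration (not repository code; the file contents are invented) of what the new ``pd.concat`` call does: metrics rows from each CSV are stacked, and each file's unit index is kept as-is, so the same unit id can appear more than once in ``testing_metrics``.

    from io import StringIO

    import pandas as pd

    # Two made-up metrics files with identical columns, indexed by unit id.
    csv_1 = StringIO("unit_id,metric1,metric2\n0,0.1,5.0\n1,0.4,3.2\n")
    csv_2 = StringIO("unit_id,metric1,metric2\n0,0.9,4.1\n1,0.2,7.5\n")

    # Same pattern as _load_data_files: read each file, then stack the rows.
    testing_metrics = pd.concat([pd.read_csv(f, index_col=0) for f in (csv_1, csv_2)], axis=0)

    print(testing_metrics.shape)  # (4, 2): rows from both files, duplicate unit ids preserved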
@@ -552,7 +552,7 @@ def train_model(
     mode="analyzers",
     labels=None,
     analyzers=None,
-    metrics_path=None,
+    metrics_paths=None,
     folder=None,
     metric_names=None,
     imputation_strategies=None,
@@ -577,8 +577,8 @@
         List of SortingAnalyzer objects containing the quality metrics and labels to use for training, if using 'analyzers' mode.
     labels : list of list | None, default: None
         List of curated labels for each unit; must be in the same order as the metrics data.
-    metrics_path : str or None, default: None
-        The path to the CSV file containing the metrics data if using 'csv' mode.
+    metrics_paths : list of str or None, default: None
+        List of paths to the CSV files containing the metrics data if using 'csv' mode.
     folder : str | None, default: None
         The folder where outputs such as models and evaluation metrics will be saved.
     metric_names : list of str | None, default: None
@@ -633,8 +633,9 @@
         trainer.load_and_preprocess_analyzers(analyzers)

     elif mode == "csv":
-        assert Path(metrics_path).is_file(), "Valid metrics path must be provided for mode 'csv'"
-        trainer.load_and_preprocess_csv(metrics_path)
+        for metrics_path in metrics_paths:
+            assert Path(metrics_path).is_file(), f"{metrics_path} is not a file."
+        trainer.load_and_preprocess_csv(metrics_paths)

     trainer.evaluate_model_config()
     return trainer
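
A small sketch (not repository code; the paths are placeholders) of how the reworked check behaves: each entry in ``metrics_paths`` is validated individually, so the AssertionError now names the specific path that is missing instead of raising a generic message.

    import tempfile
    from pathlib import Path

    # One real temporary CSV and one deliberately missing path.
    existing_csv = tempfile.NamedTemporaryFile(suffix=".csv", delete=False).name
    metrics_paths = [existing_csv, "/path/that/does/not/exist.csv"]

    try:
        for metrics_path in metrics_paths:
            assert Path(metrics_path).is_file(), f"{metrics_path} is not a file."
    except AssertionError as err:
        print(err)  # -> /path/that/does/not/exist.csv is not a file.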
