diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 2c3515e06b..97547c5f7e 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -15,26 +15,32 @@ @pytest.fixture def model(): + """A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`. + It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with + the following labels: [1,0,1,0,1].""" model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"]) - return model @pytest.fixture def required_metrics(): - + """These are the metrics which `model` are trained on.""" return ["num_spikes", "snr", "half_width"] def test_model_based_classification_init(sorting_analyzer_for_curation, model): - # Test the initialization of ModelBasedClassification + """Test that the ModelBasedClassification attributes are correctly initialised""" + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) assert model_based_classification.sorting_analyzer == sorting_analyzer_for_curation assert model_based_classification.pipeline == model[0] + assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_) def test_metric_ordering_independence(sorting_analyzer_for_curation, model): + """The function `auto_label_units` needs the correct metrics to have been computed. However, + it should be independent of the order of computation. 
We test this here.""" sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) @@ -61,57 +67,35 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, model): def test_model_based_classification_get_metrics_for_classification( sorting_analyzer_for_curation, model, required_metrics ): + """If the user has not computed the required metrics, an error should be returned. + This test checks that an error occurs when the required metrics have not been computed, + and that no error is returned when the required metrics have been computed. + """ sorting_analyzer_for_curation.delete_extension("quality_metrics") sorting_analyzer_for_curation.delete_extension("template_metrics") - # Test the _check_required_metrics_are_present() method of ModelBasedClassification model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) - # Check that ValueError is returned when quality_metrics are not present in sorting_analyzer + # Check that ValueError is returned when no metrics are present in sorting_analyzer with pytest.raises(ValueError): computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) - # Compute some (but not all) of the required metrics in sorting_analyzer + # Compute some (but not all) of the required metrics in sorting_analyzer, should still error sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metrics[0]]) computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) with pytest.raises(ValueError): model_based_classification._check_required_metrics_are_present(computed_metrics) - # Compute all of the required metrics in sorting_analyzer + # Compute all of the required metrics in sorting_analyzer, no more error sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2]) 
sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]]) - # Check that the metrics data is returned as a pandas DataFrame metrics_data = _get_computed_metrics(sorting_analyzer_for_curation) assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids()) assert set(metrics_data.columns.to_list()) == set(required_metrics) -def test_model_based_classification_check_params_for_classification( - sorting_analyzer_for_curation, model, required_metrics -): - # Make a fresh copy of the sorting_analyzer to remove any calculated metrics - sorting_analyzer_for_curation = make_sorting_analyzer() - - # Test the _check_params_for_classification() method of ModelBasedClassification - model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) - - # Check that function runs without error when required_metrics are computed - sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2]) - sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]]) - - model_info = {"metric_params": {}} - model_info["metric_params"]["quality_metric_params"] = sorting_analyzer_for_curation.get_extension( - "quality_metrics" - ).params - model_info["metric_params"]["template_metric_params"] = sorting_analyzer_for_curation.get_extension( - "template_metrics" - ).params - - model_based_classification._check_params_for_classification(model_info=model_info) - - def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation, model): # Test the _export_to_phy() method of ModelBasedClassification model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) @@ -145,7 +129,38 @@ def test_model_based_classification_predict_labels(sorting_analyzer_for_curation assert np.all(predictions_labelled == ["good", "noise", "good", "noise", "good"]) +def 
test_model_based_classification_check_params_for_classification( + sorting_analyzer_for_curation, model, required_metrics +): + """ """ + # Make a fresh copy of the sorting_analyzer to remove any calculated metrics + sorting_analyzer_for_curation.delete_extension("quality_metrics") + sorting_analyzer_for_curation.delete_extension("template_metrics") + + # Test the _check_params_for_classification() method of ModelBasedClassification + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + + # Check that function runs without error when required_metrics are computed + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2]) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]]) + + model_info = {"metric_params": {}} + model_info["metric_params"]["quality_metric_params"] = sorting_analyzer_for_curation.get_extension( + "quality_metrics" + ).params + model_info["metric_params"]["template_metric_params"] = sorting_analyzer_for_curation.get_extension( + "template_metrics" + ).params + + model_based_classification._check_params_for_classification(model_info=model_info) + + def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curation): + """We track whether the metric parameters used to compute the metrics used to train + a model are the same as the parameters used to compute the metrics in the sorting + analyzer which is being curated. If they are different, an error or warning will + be raised depending on the `enforce_metric_params` kwarg. 
This behaviour is tested here.""" + sorting_analyzer_for_curation.compute( "quality_metrics", metric_names=["num_spikes", "snr"], qm_params={"snr": {"peak_mode": "peak_to_peak"}} ) @@ -153,27 +168,21 @@ def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curat model_folder = Path(__file__).parent / Path("trained_pipeline") + model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + # an error should be raised if `enforce_metric_params` is True with pytest.raises(Exception): - auto_label_units( - sorting_analyzer=sorting_analyzer_for_curation, - model_folder=model_folder, - enforce_metric_params=True, - trusted=["numpy.dtype"], - ) - - # but not if `enforce_metric_params` is False - auto_label_units( - sorting_analyzer=sorting_analyzer_for_curation, - model_folder=model_folder, - enforce_metric_params=False, - trusted=["numpy.dtype"], - ) + model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) + + # but only a warning if `enforce_metric_params` is False + with pytest.warns(UserWarning): + model_based_classification._check_params_for_classification(enforce_metric_params=False, model_info=model_info) - classifer_labels = sorting_analyzer_for_curation.get_sorting_property("classifier_label") - assert isinstance(classifer_labels, np.ndarray) - assert len(classifer_labels) == sorting_analyzer_for_curation.get_num_units() + # Now test the positive case. 
Recompute using the default parameters + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"], qm_params={}) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) - classifier_probabilities = sorting_analyzer_for_curation.get_sorting_property("classifier_probability") - assert isinstance(classifier_probabilities, np.ndarray) - assert len(classifier_probabilities) == sorting_analyzer_for_curation.get_num_units() + model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py index 59b8565200..d6e9c97f55 100644 --- a/src/spikeinterface/curation/tests/test_train_manual_curation.py +++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py @@ -9,6 +9,8 @@ @pytest.fixture def trainer(): + """A simple CurationModelTrainer object is created, which can later be used to + train models using data from `sorting_analyzer`s.""" folder = tempfile.mkdtemp() # Create a temporary output folder imputation_strategies = ["median"] @@ -26,7 +28,10 @@ def trainer(): def make_temp_training_csv(): - # Create a temporary CSV file with sham data + """Create a temporary CSV file with artificially generated quality metrics. + The data is designed to be easy to discern between units. Even units metric + values are all `0`, while odd units metric values are all `1`. 
+ """ with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file: writer = csv.writer(temp_file) writer.writerow(["unit_id", "metric1", "metric2", "metric3"]) @@ -37,70 +42,135 @@ def make_temp_training_csv(): def test_load_and_preprocess_full(trainer): + """Check that we load and preprocess the csv file from `make_temp_training_csv` + correctly.""" temp_file_path = make_temp_training_csv() # Load and preprocess the data from the temporary CSV file trainer.load_and_preprocess_csv([temp_file_path]) # Assert that the data is loaded and preprocessed correctly - assert trainer.X is not None - assert trainer.y is not None - assert trainer.testing_metrics is not None + for a, row in trainer.X.iterrows(): + assert np.all(row.values == [float(a % 2)] * 3) + for a, label in enumerate(trainer.y.values): + assert label == a % 2 + for a, row in trainer.testing_metrics.iterrows(): + assert np.all(row.values == [a % 2] * 3) + assert row.name == a def test_apply_scaling_imputation(trainer): + """Take a simple training and test set and check that they are correctly scaled, + using a standard scaler which rescales the training distribution to have mean 0 + and variance 1. Length between each row is 3, so if x0 is the first value in the + column, all other values are scaled as x -> 2/3(x - x0) - 1. 
The y (labelled) values + do not get scaled.""" + + from sklearn.impute._knn import KNNImputer + from sklearn.preprocessing._data import StandardScaler imputation_strategy = "knn" scaling_technique = "standard_scaler" X_train = np.array([[1, 2, 3], [4, 5, 6]]) - X_val = np.array([[7, 8, 9], [10, 11, 12]]) + X_test = np.array([[7, 8, 9], [10, 11, 12]]) y_train = np.array([0, 1]) - y_val = np.array([2, 3]) - X_train_scaled, X_val_scaled, y_train, y_val, imputer, scaler = trainer.apply_scaling_imputation( - imputation_strategy, scaling_technique, X_train, X_val, y_train, y_val + y_test = np.array([2, 3]) + + X_train_scaled, X_test_scaled, y_train_scaled, y_test_scaled, imputer, scaler = trainer.apply_scaling_imputation( + imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test ) - assert X_train_scaled is not None - assert X_val_scaled is not None - assert y_train is not None - assert y_val is not None - assert imputer is not None - assert scaler is not None + + first_row_elements = X_train[0] + for a, row in enumerate(X_train): + assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_train_scaled[a]) + for a, row in enumerate(X_test): + assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_test_scaled[a]) + + assert np.all(y_train == y_train_scaled) + assert np.all(y_test == y_test_scaled) + + print(type(scaler)) + + assert isinstance(imputer, KNNImputer) + assert isinstance(scaler, StandardScaler) def test_get_classifier_search_space(trainer): + """For each classifier, there is a hyperparameter space we search over to find its + most accurate incarnation. 
Here, we check that we do indeed load the appropriate + dict of hyperparameter possibilities""" + + from sklearn.linear_model._logistic import LogisticRegression classifier = "LogisticRegression" model, param_space = trainer.get_classifier_search_space(classifier) - assert model is not None + + assert isinstance(model, LogisticRegression) + assert len(param_space) > 0 assert isinstance(param_space, dict) def test_get_custom_classifier_search_space(): + """Check that if a user passes a custom hyperparameter search space, that this is + passed correctly to the trainer.""" + classifier = { "LogisticRegression": { - "C": [0.001, 8.0], - "solver": ["newton-cg", "lbfgs", "liblinear", "sag", "saga"], + "C": [0.1, 8.0], + "solver": ["lbfgs"], "max_iter": [100, 400], } } trainer = CurationModelTrainer(classifiers=classifier, labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]]) model, param_space = trainer.get_classifier_search_space(list(classifier.keys())[0]) - assert model is not None assert param_space == classifier["LogisticRegression"] -def test_evaluate_model_config(trainer): +def test_saved_files(trainer): + """During the trainer's creation, the following files should be created: + - best_model.skops + - labels.csv + - model_accuracies.csv + - model_info.json + - training_data.csv + This test checks that these exist, and checks some properties of the files.""" + + import pandas as pd + import json - trainer.X = np.ones((10, 3)) + trainer.X = np.random.rand(10, 3) trainer.y = np.append(np.ones(5), np.zeros(5)) trainer.evaluate_model_config() trainer_folder = Path(trainer.folder) + assert trainer_folder.is_dir() - assert (trainer_folder / "best_model.skops").is_file() - assert (trainer_folder / "model_accuracies.csv").is_file() - assert (trainer_folder / "model_info.json").is_file() + + best_model_path = trainer_folder / "best_model.skops" + model_accuracies_path = trainer_folder / "model_accuracies.csv" + training_data_path = trainer_folder / "training_data.csv" + labels_path = 
trainer_folder / "labels.csv" + model_info_path = trainer_folder / "model_info.json" + + assert (best_model_path).is_file() + + model_accuracies = pd.read_csv(model_accuracies_path) + model_accuracies["classifier name"].values[0] == "LogisticRegression" + assert len(model_accuracies) == 1 + + training_data = pd.read_csv(training_data_path) + assert np.all(np.isclose(training_data.values[:, 1:4], trainer.X, rtol=1e-10)) + + labels = pd.read_csv(labels_path) + assert np.all(labels.values[:, 1] == trainer.y.astype("float")) + + model_info = pd.read_json(model_info_path) + + with open(model_info_path) as f: + model_info = json.load(f) + + assert set(model_info.keys()) == set(["metric_params", "requirements", "label_conversion"]) def test_train_model_using_two_csvs():