
start responding to test comments
chrishalcrow committed Nov 29, 2024
1 parent 8dacf2a commit 830a977
Showing 2 changed files with 154 additions and 75 deletions.
113 changes: 61 additions & 52 deletions src/spikeinterface/curation/tests/test_model_based_curation.py
@@ -15,26 +15,32 @@

@pytest.fixture
def model():
"""A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`.
It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with
the following labels: [1,0,1,0,1]."""

model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"])

return model


@pytest.fixture
def required_metrics():

"""These are the metrics which `model` are trained on."""
return ["num_spikes", "snr", "half_width"]


def test_model_based_classification_init(sorting_analyzer_for_curation, model):
# Test the initialization of ModelBasedClassification
"""Test that the ModelBasedClassification attributes are correctly initialised"""

model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])
assert model_based_classification.sorting_analyzer == sorting_analyzer_for_curation
assert model_based_classification.pipeline == model[0]
assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_)


def test_metric_ordering_independence(sorting_analyzer_for_curation, model):
"""The function `auto_label_units` needs the correct metrics to have been computed. However,
it should be independent of the order of computation. We test this here."""

sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])
@@ -61,57 +67,35 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, model):
def test_model_based_classification_get_metrics_for_classification(
sorting_analyzer_for_curation, model, required_metrics
):
"""If the user has not computed the required metrics, an error should be returned.
This test checks that an error occurs when the required metrics have not been computed,
and that no error is returned when the required metrics have been computed.
"""

sorting_analyzer_for_curation.delete_extension("quality_metrics")
sorting_analyzer_for_curation.delete_extension("template_metrics")

# Test the _check_required_metrics_are_present() method of ModelBasedClassification
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])

# Check that ValueError is returned when quality_metrics are not present in sorting_analyzer
# Check that ValueError is returned when no metrics are present in sorting_analyzer
with pytest.raises(ValueError):
computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation)

# Compute some (but not all) of the required metrics in sorting_analyzer
# Compute some (but not all) of the required metrics in sorting_analyzer, should still error
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metrics[0]])
computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation)
with pytest.raises(ValueError):
model_based_classification._check_required_metrics_are_present(computed_metrics)

# Compute all of the required metrics in sorting_analyzer
# Compute all of the required metrics in sorting_analyzer, no more error
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2])
sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]])

# Check that the metrics data is returned as a pandas DataFrame
metrics_data = _get_computed_metrics(sorting_analyzer_for_curation)
assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids())
assert set(metrics_data.columns.to_list()) == set(required_metrics)


def test_model_based_classification_check_params_for_classification(
sorting_analyzer_for_curation, model, required_metrics
):
# Make a fresh copy of the sorting_analyzer to remove any calculated metrics
sorting_analyzer_for_curation = make_sorting_analyzer()

# Test the _check_params_for_classification() method of ModelBasedClassification
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])

# Check that function runs without error when required_metrics are computed
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2])
sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]])

model_info = {"metric_params": {}}
model_info["metric_params"]["quality_metric_params"] = sorting_analyzer_for_curation.get_extension(
"quality_metrics"
).params
model_info["metric_params"]["template_metric_params"] = sorting_analyzer_for_curation.get_extension(
"template_metrics"
).params

model_based_classification._check_params_for_classification(model_info=model_info)


def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation, model):
# Test the _export_to_phy() method of ModelBasedClassification
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])
@@ -145,35 +129,60 @@ def test_model_based_classification_predict_labels(sorting_analyzer_for_curation
assert np.all(predictions_labelled == ["good", "noise", "good", "noise", "good"])


def test_model_based_classification_check_params_for_classification(
sorting_analyzer_for_curation, model, required_metrics
):
""" """
# Make a fresh copy of the sorting_analyzer to remove any calculated metrics
sorting_analyzer_for_curation.delete_extension("quality_metrics")
sorting_analyzer_for_curation.delete_extension("template_metrics")

# Test the _check_params_for_classification() method of ModelBasedClassification
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])

# Check that function runs without error when required_metrics are computed
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2])
sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]])

model_info = {"metric_params": {}}
model_info["metric_params"]["quality_metric_params"] = sorting_analyzer_for_curation.get_extension(
"quality_metrics"
).params
model_info["metric_params"]["template_metric_params"] = sorting_analyzer_for_curation.get_extension(
"template_metrics"
).params

model_based_classification._check_params_for_classification(model_info=model_info)


def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curation):
"""We track whether the metric parameters used to compute the metrics used to train
a model are the same as the parameters used to compute the metrics in the sorting
analyzer which is being curated. If they are different, an error or warning will
be raised depending on the `enforce_metric_params` kwarg. This behaviour is tested here."""

sorting_analyzer_for_curation.compute(
"quality_metrics", metric_names=["num_spikes", "snr"], qm_params={"snr": {"peak_mode": "peak_to_peak"}}
)
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])

model_folder = Path(__file__).parent / Path("trained_pipeline")

model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"])
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)

# an error should be raised if `enforce_metric_params` is True
with pytest.raises(Exception):
auto_label_units(
sorting_analyzer=sorting_analyzer_for_curation,
model_folder=model_folder,
enforce_metric_params=True,
trusted=["numpy.dtype"],
)

# but not if `enforce_metric_params` is False
auto_label_units(
sorting_analyzer=sorting_analyzer_for_curation,
model_folder=model_folder,
enforce_metric_params=False,
trusted=["numpy.dtype"],
)
model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)

# but only a warning if `enforce_metric_params` is False
with pytest.warns(UserWarning):
model_based_classification._check_params_for_classification(enforce_metric_params=False, model_info=model_info)

classifier_labels = sorting_analyzer_for_curation.get_sorting_property("classifier_label")
assert isinstance(classifier_labels, np.ndarray)
assert len(classifier_labels) == sorting_analyzer_for_curation.get_num_units()
# Now test the positive case. Recompute using the default parameters
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"], qm_params={})
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])

classifier_probabilities = sorting_analyzer_for_curation.get_sorting_property("classifier_probability")
assert isinstance(classifier_probabilities, np.ndarray)
assert len(classifier_probabilities) == sorting_analyzer_for_curation.get_num_units()
model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"])
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)
model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)
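
For reference, the check exercised above compares the metric parameters stored alongside the model (in `model_info`) with the parameters currently stored on the analyzer's extensions. A rough sketch of that comparison, using only attributes that appear in these tests (the internal logic of `_check_params_for_classification` is an assumption, and the exception type is illustrative):

# Hedged sketch of the parameter check: a mismatch either raises or warns,
# depending on enforce_metric_params. Only the quality-metrics branch is shown.
import warnings

enforce_metric_params = True  # mirrors the kwarg used in the tests above
trained_params = model_info["metric_params"]["quality_metric_params"]
current_params = sorting_analyzer_for_curation.get_extension("quality_metrics").params

if trained_params != current_params:
    message = "Quality metric parameters differ from those used to train the model."
    if enforce_metric_params:
        raise ValueError(message)
    else:
        warnings.warn(message)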
116 changes: 93 additions & 23 deletions src/spikeinterface/curation/tests/test_train_manual_curation.py
@@ -9,6 +9,8 @@

@pytest.fixture
def trainer():
"""A simple CurationModelTrainer object is created, which can later by used to
train models using data from `sorting_analyzer`s."""

folder = tempfile.mkdtemp() # Create a temporary output folder
imputation_strategies = ["median"]
@@ -26,7 +28,10 @@ def trainer():


def make_temp_training_csv():
# Create a temporary CSV file with sham data
"""Create a temporary CSV file with artificially generated quality metrics.
The data is designed to be easy to dicern between units. Even units metric
values are all `0`, while odd units metric values are all `1`.
"""
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
writer = csv.writer(temp_file)
writer.writerow(["unit_id", "metric1", "metric2", "metric3"])
@@ -37,70 +42,135 @@


def test_load_and_preprocess_full(trainer):
"""Check that we load and preprocess the csv file from `make_temp_training_csv`
correctly."""
temp_file_path = make_temp_training_csv()

# Load and preprocess the data from the temporary CSV file
trainer.load_and_preprocess_csv([temp_file_path])

# Assert that the data is loaded and preprocessed correctly
assert trainer.X is not None
assert trainer.y is not None
assert trainer.testing_metrics is not None
for a, row in trainer.X.iterrows():
assert np.all(row.values == [float(a % 2)] * 3)
for a, label in enumerate(trainer.y.values):
assert label == a % 2
for a, row in trainer.testing_metrics.iterrows():
assert np.all(row.values == [a % 2] * 3)
assert row.name == a
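
As a quick illustration of the pattern these assertions rely on, the temporary CSV can also be read back directly; the pandas usage below is a sketch rather than part of the test suite, and `temp_file_path` is the path created above:

# Hedged sketch: even unit_ids carry all-zero metric values, odd unit_ids all-one values.
import pandas as pd

df = pd.read_csv(temp_file_path, index_col="unit_id")
assert all((df.loc[unit_id] == unit_id % 2).all() for unit_id in df.index)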


def test_apply_scaling_imputation(trainer):
"""Take a simple training and test set and check that they are corrected scaled,
using a standard scaler which rescales the training distribution to have mean 0
and variance 1. Length between each row is 3, so if x0 is the first value in the
column, all other values are scaled as x -> 2/3(x - x0) - 1. The y (labled) values
do not get scaled."""

from sklearn.impute._knn import KNNImputer
from sklearn.preprocessing._data import StandardScaler

imputation_strategy = "knn"
scaling_technique = "standard_scaler"
X_train = np.array([[1, 2, 3], [4, 5, 6]])
X_val = np.array([[7, 8, 9], [10, 11, 12]])
X_test = np.array([[7, 8, 9], [10, 11, 12]])
y_train = np.array([0, 1])
y_val = np.array([2, 3])
X_train_scaled, X_val_scaled, y_train, y_val, imputer, scaler = trainer.apply_scaling_imputation(
imputation_strategy, scaling_technique, X_train, X_val, y_train, y_val
y_test = np.array([2, 3])

X_train_scaled, X_test_scaled, y_train_scaled, y_test_scaled, imputer, scaler = trainer.apply_scaling_imputation(
imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test
)
assert X_train_scaled is not None
assert X_val_scaled is not None
assert y_train is not None
assert y_val is not None
assert imputer is not None
assert scaler is not None

first_row_elements = X_train[0]
for a, row in enumerate(X_train):
assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_train_scaled[a])
for a, row in enumerate(X_test):
assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_test_scaled[a])

assert np.all(y_train == y_train_scaled)
assert np.all(y_test == y_test_scaled)

assert isinstance(imputer, KNNImputer)
assert isinstance(scaler, StandardScaler)
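
A short worked check of the scaling formula quoted in the docstring: sklearn's StandardScaler uses the population standard deviation, so for training values x0 and x0 + 3 in a column the mean is x0 + 1.5 and the standard deviation is 1.5, giving (x - mean) / std = 2/3 * (x - x0) - 1. A standalone verification (not part of the test suite):

# Worked verification of the 2/3 * (x - x0) - 1 formula for this training set.
import numpy as np
from sklearn.preprocessing import StandardScaler

X_train = np.array([[1, 2, 3], [4, 5, 6]], dtype=float)
scaler = StandardScaler().fit(X_train)
x0 = X_train[0]
assert np.allclose(scaler.transform(X_train), 2 / 3 * (X_train - x0) - 1.0)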


def test_get_classifier_search_space(trainer):
"""For each classifier, there is a hyperparameter space we search over to find its
most accurate incarnation. Here, we check that we do indeed load the approprirate
dict of hyperparameter possibilities"""

from sklearn.linear_model._logistic import LogisticRegression

classifier = "LogisticRegression"
model, param_space = trainer.get_classifier_search_space(classifier)
assert model is not None

assert isinstance(model, LogisticRegression)
assert len(param_space) > 0
assert isinstance(param_space, dict)


def test_get_custom_classifier_search_space():
"""Check that if a user passes a custom hyperparameter search space, that this is
passed correctly to the trainer."""

classifier = {
"LogisticRegression": {
"C": [0.001, 8.0],
"solver": ["newton-cg", "lbfgs", "liblinear", "sag", "saga"],
"C": [0.1, 8.0],
"solver": ["lbfgs"],
"max_iter": [100, 400],
}
}
trainer = CurationModelTrainer(classifiers=classifier, labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]])

model, param_space = trainer.get_classifier_search_space(list(classifier.keys())[0])
assert model is not None
assert param_space == classifier["LogisticRegression"]


def test_evaluate_model_config(trainer):
def test_saved_files(trainer):
"""During the trainer's creation, the following files should be created:
- best_model.skops
- labels.csv
- model_accuracies.csv
- model_info.json
- training_data.csv
This test checks that these exist, and checks some properties of the files."""

import pandas as pd
import json

trainer.X = np.ones((10, 3))
trainer.X = np.random.rand(10, 3)
trainer.y = np.append(np.ones(5), np.zeros(5))

trainer.evaluate_model_config()
trainer_folder = Path(trainer.folder)

assert trainer_folder.is_dir()
assert (trainer_folder / "best_model.skops").is_file()
assert (trainer_folder / "model_accuracies.csv").is_file()
assert (trainer_folder / "model_info.json").is_file()

best_model_path = trainer_folder / "best_model.skops"
model_accuracies_path = trainer_folder / "model_accuracies.csv"
training_data_path = trainer_folder / "training_data.csv"
labels_path = trainer_folder / "labels.csv"
model_info_path = trainer_folder / "model_info.json"

assert (best_model_path).is_file()

model_accuracies = pd.read_csv(model_accuracies_path)
model_accuracies["classifier name"].values[0] == "LogisticRegression"
assert len(model_accuracies) == 1

training_data = pd.read_csv(training_data_path)
assert np.all(np.isclose(training_data.values[:, 1:4], trainer.X, rtol=1e-10))

labels = pd.read_csv(labels_path)
assert np.all(labels.values[:, 1] == trainer.y.astype("float"))

model_info = pd.read_json(model_info_path)

with open(model_info_path) as f:
model_info = json.load(f)

assert set(model_info.keys()) == set(["metric_params", "requirements", "label_conversion"])
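
As a follow-up sketch, the folder written by the trainer should be loadable with the same `load_model` helper exercised in `test_model_based_curation.py` above (the import path and the `trusted` types needed here are assumptions):

# Hedged sketch: reload the saved pipeline and its metadata from the trainer's folder.
from spikeinterface.curation import load_model

model, model_info = load_model(model_folder=trainer.folder, trusted=["numpy.dtype"])
assert set(model_info.keys()) == {"metric_params", "requirements", "label_conversion"}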


def test_train_model_using_two_csvs():
