Skip to content

Commit

Permalink
add test documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Nov 29, 2024
1 parent 830a977 commit 0eb8751
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions src/spikeinterface/curation/tests/test_train_manual_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,19 +173,17 @@ def test_saved_files(trainer):
assert set(model_info.keys()) == set(["metric_params", "requirements", "label_conversion"])


def test_train_model_using_two_csvs():

metrics_path_1 = make_temp_training_csv()
metrics_path_2 = make_temp_training_csv()
def test_train_model():
"""A simple function test to check that `train_model` doesn't fail with one csv inputs"""

metrics_path = make_temp_training_csv()
folder = tempfile.mkdtemp()
metric_names = ["metric1", "metric2", "metric3"]

trainer = train_model(
mode="csv",
metrics_paths=[metrics_path_1, metrics_path_2],
metrics_paths=[metrics_path],
folder=folder,
labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]],
labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]],
metric_names=metric_names,
imputation_strategies=["median"],
scaling_techniques=["standard_scaler"],
Expand All @@ -195,16 +193,21 @@ def test_train_model_using_two_csvs():
assert isinstance(trainer, CurationModelTrainer)


def test_train_model():
def test_train_model_using_two_csvs():
"""Models can be trained using more than one set of training data. This test checks
that `train_model` works with two inputs, from csv files."""

metrics_path_1 = make_temp_training_csv()
metrics_path_2 = make_temp_training_csv()

metrics_path = make_temp_training_csv()
folder = tempfile.mkdtemp()
metric_names = ["metric1", "metric2", "metric3"]

trainer = train_model(
mode="csv",
metrics_paths=[metrics_path],
metrics_paths=[metrics_path_1, metrics_path_2],
folder=folder,
labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]],
labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]],
metric_names=metric_names,
imputation_strategies=["median"],
scaling_techniques=["standard_scaler"],
Expand All @@ -215,6 +218,9 @@ def test_train_model():


def test_train_using_two_sorting_analyzers():
"""Models can be trained using more than one set of training data. This test checks
that `train_model` works with two inputs, from sorting analzyers. It also checks that
an error is raised if the sorting_analyzers have different sets of metrics computed."""

sorting_analyzer_1 = make_sorting_analyzer()
sorting_analyzer_1.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}})
Expand All @@ -238,8 +244,7 @@ def test_train_using_two_sorting_analyzers():

assert isinstance(trainer, CurationModelTrainer)

# Xheck that there is an error raised if the metric names are different

# Check that there is an error raised if the metric names are different
sorting_analyzer_2 = make_sorting_analyzer()
sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes"], "delete_existing_metrics": True}})

Expand Down

0 comments on commit 0eb8751

Please sign in to comment.