From 0eb8751db271939ea443e9f38d605e45a1253692 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Fri, 29 Nov 2024 17:26:09 +0000 Subject: [PATCH] add test documentation --- .../tests/test_train_manual_curation.py | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py index d6e9c97f55..720a4fbf96 100644 --- a/src/spikeinterface/curation/tests/test_train_manual_curation.py +++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py @@ -173,19 +173,17 @@ def test_saved_files(trainer): assert set(model_info.keys()) == set(["metric_params", "requirements", "label_conversion"]) -def test_train_model_using_two_csvs(): - - metrics_path_1 = make_temp_training_csv() - metrics_path_2 = make_temp_training_csv() +def test_train_model(): + """A simple function test to check that `train_model` doesn't fail with one csv input""" + metrics_path = make_temp_training_csv() folder = tempfile.mkdtemp() metric_names = ["metric1", "metric2", "metric3"] - trainer = train_model( mode="csv", - metrics_paths=[metrics_path_1, metrics_path_2], + metrics_paths=[metrics_path], folder=folder, - labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], metric_names=metric_names, imputation_strategies=["median"], scaling_techniques=["standard_scaler"], @@ -195,16 +193,21 @@ def test_train_model_using_two_csvs(): assert isinstance(trainer, CurationModelTrainer) -def test_train_model(): +def test_train_model_using_two_csvs(): + """Models can be trained using more than one set of training data. 
This test checks + that `train_model` works with two inputs, from csv files.""" + + metrics_path_1 = make_temp_training_csv() + metrics_path_2 = make_temp_training_csv() - metrics_path = make_temp_training_csv() folder = tempfile.mkdtemp() metric_names = ["metric1", "metric2", "metric3"] + trainer = train_model( mode="csv", - metrics_paths=[metrics_path], + metrics_paths=[metrics_path_1, metrics_path_2], folder=folder, - labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], metric_names=metric_names, imputation_strategies=["median"], scaling_techniques=["standard_scaler"], @@ -215,6 +218,9 @@ def test_train_model(): def test_train_using_two_sorting_analyzers(): + """Models can be trained using more than one set of training data. This test checks + that `train_model` works with two inputs, from sorting analyzers. It also checks that + an error is raised if the sorting_analyzers have different sets of metrics computed.""" sorting_analyzer_1 = make_sorting_analyzer() sorting_analyzer_1.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}}) @@ -238,8 +244,7 @@ def test_train_using_two_sorting_analyzers(): assert isinstance(trainer, CurationModelTrainer) - # Xheck that there is an error raised if the metric names are different - + # Check that there is an error raised if the metric names are different sorting_analyzer_2 = make_sorting_analyzer() sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes"], "delete_existing_metrics": True}})