diff --git a/conftest.py b/conftest.py index 040818320e..c4bac6628a 100644 --- a/conftest.py +++ b/conftest.py @@ -14,16 +14,10 @@ "widgets", "exporters", "sortingcomponents", "generation"] -# define global test folder -def pytest_sessionstart(session): - # setup_stuff - pytest.global_test_folder = Path(__file__).parent / "test_folder" - if pytest.global_test_folder.is_dir(): - shutil.rmtree(pytest.global_test_folder) - pytest.global_test_folder.mkdir() - - for mark_name in mark_names: - (pytest.global_test_folder / mark_name).mkdir() +@pytest.fixture(scope="module") +def create_cache_folder(tmp_path_factory): + cache_folder = tmp_path_factory.mktemp("cache_folder") + return cache_folder def pytest_collection_modifyitems(config, items): """ @@ -45,12 +39,3 @@ def pytest_collection_modifyitems(config, items): item.add_marker("sorters") else: item.add_marker(module) - - - -def pytest_sessionfinish(session, exitstatus): - # teardown_stuff only if tests passed - # We don't delete the test folder in the CI because it was causing problems with the code coverage. 
- if exitstatus == 0: - if pytest.global_test_folder.is_dir() and not ON_GITHUB: - shutil.rmtree(pytest.global_test_folder) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index b7df085fab..a92d6e9f77 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -7,19 +7,13 @@ from spikeinterface.comparison import GroundTruthStudy -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "comparison" -else: - cache_folder = Path("cache_folder") / "comparison" - cache_folder.mkdir(exist_ok=True, parents=True) - -study_folder = cache_folder / "test_groundtruthstudy/" - - -def setup_module(): +@pytest.fixture(scope="module") +def setup_module(tmp_path_factory): + study_folder = tmp_path_factory.mktemp("study_folder") if study_folder.is_dir(): shutil.rmtree(study_folder) create_a_study(study_folder) + return study_folder def simple_preprocess(rec): @@ -74,7 +68,8 @@ def create_a_study(study_folder): # print(study) -def test_GroundTruthStudy(): +def test_GroundTruthStudy(setup_module): + study_folder = setup_module study = GroundTruthStudy(study_folder) print(study) diff --git a/src/spikeinterface/comparison/tests/test_hybrid.py b/src/spikeinterface/comparison/tests/test_hybrid.py index 8c392f7687..ce409ca778 100644 --- a/src/spikeinterface/comparison/tests/test_hybrid.py +++ b/src/spikeinterface/comparison/tests/test_hybrid.py @@ -11,13 +11,9 @@ from spikeinterface.preprocessing import bandpass_filter -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "comparison" / "hybrid" -else: - cache_folder = Path("cache_folder") / "comparison" / "hybrid" - - -def setup_module(): +@pytest.fixture(scope="module") +def setup_module(tmp_path_factory): + cache_folder = tmp_path_factory.mktemp("cache_folder") if cache_folder.is_dir(): 
shutil.rmtree(cache_folder) cache_folder.mkdir(parents=True, exist_ok=True) @@ -31,9 +27,11 @@ def setup_module(): wvf_extractor = extract_waveforms( recording, sorting, folder=cache_folder / "wvf_extractor", ms_before=10.0, ms_after=10.0 ) + return cache_folder -def test_hybrid_units_recording(): +def test_hybrid_units_recording(setup_module): + cache_folder = setup_module wvf_extractor = load_waveforms(cache_folder / "wvf_extractor") print(wvf_extractor) print(wvf_extractor.sorting_analyzer) @@ -63,7 +61,8 @@ def test_hybrid_units_recording(): check_recordings_equal(hybrid_units_recording, saved_2job, return_scaled=False) -def test_hybrid_spikes_recording(): +def test_hybrid_spikes_recording(setup_module): + cache_folder = setup_module wvf_extractor = load_waveforms(cache_folder / "wvf_extractor") recording = wvf_extractor.recording sorting = wvf_extractor.sorting diff --git a/src/spikeinterface/comparison/tests/test_multisortingcomparison.py b/src/spikeinterface/comparison/tests/test_multisortingcomparison.py index f39b8cd890..9ea8ba3e80 100644 --- a/src/spikeinterface/comparison/tests/test_multisortingcomparison.py +++ b/src/spikeinterface/comparison/tests/test_multisortingcomparison.py @@ -1,6 +1,4 @@ import shutil -import pytest -from pathlib import Path import pytest import numpy as np @@ -9,18 +7,14 @@ from spikeinterface.extractors import NumpySorting from spikeinterface.comparison import compare_multiple_sorters, MultiSortingComparison -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "comparison" -else: - cache_folder = Path("cache_folder") / "comparison" - - -multicomparison_folder = cache_folder / "saved_multisorting_comparison" - -def setup_module(): +@pytest.fixture(scope="module") +def setup_module(tmp_path_factory): + cache_folder = tmp_path_factory.mktemp("cache_folder") + multicomparison_folder = cache_folder / "saved_multisorting_comparison" if multicomparison_folder.is_dir(): 
shutil.rmtree(multicomparison_folder) + return multicomparison_folder def make_sorting(times1, labels1, times2, labels2, times3, labels3): @@ -34,7 +28,8 @@ def make_sorting(times1, labels1, times2, labels2, times3, labels3): return sorting1, sorting2, sorting3 -def test_compare_multiple_sorters(): +def test_compare_multiple_sorters(setup_module): + multicomparison_folder = setup_module # simple match sorting1, sorting2, sorting3 = make_sorting( [100, 200, 300, 400, 500, 600, 700, 800, 900], diff --git a/src/spikeinterface/comparison/tests/test_templatecomparison.py b/src/spikeinterface/comparison/tests/test_templatecomparison.py index f4b3a14f0d..871bdeaed3 100644 --- a/src/spikeinterface/comparison/tests/test_templatecomparison.py +++ b/src/spikeinterface/comparison/tests/test_templatecomparison.py @@ -1,21 +1,11 @@ import shutil import pytest -from pathlib import Path import numpy as np from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording from spikeinterface.comparison import compare_templates, compare_multiple_templates -# if hasattr(pytest, "global_test_folder"): -# cache_folder = pytest.global_test_folder / "comparison" -# else: -# cache_folder = Path("cache_folder") / "comparison" - - -# test_dir = cache_folder / "temp_comp_test" - - # def setup_module(): # if test_dir.is_dir(): # shutil.rmtree(test_dir) diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index 8991c959ad..b4d96a3391 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -1,5 +1,4 @@ import pytest -from pathlib import Path import shutil @@ -11,13 +10,8 @@ import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - -def get_sorting_analyzer(format="memory", sparse=True): 
+def get_sorting_analyzer(cache_folder, format="memory", sparse=True): recording, sorting = generate_ground_truth_recording( durations=[30.0], sampling_frequency=16000.0, @@ -53,7 +47,7 @@ def get_sorting_analyzer(format="memory", sparse=True): return sorting_analyzer -def _check_result_extension(sorting_analyzer, extension_name): +def _check_result_extension(sorting_analyzer, extension_name, cache_folder): # select unit_ids to several format for format in ("memory", "binary_folder", "zarr"): # for format in ("memory", ): @@ -83,39 +77,42 @@ def _check_result_extension(sorting_analyzer, extension_name): False, ], ) -def test_ComputeRandomSpikes(format, sparse): - sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) +def test_ComputeRandomSpikes(format, sparse, create_cache_folder): + cache_folder = create_cache_folder + sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse) ext = sorting_analyzer.compute("random_spikes", max_spikes_per_unit=10, seed=2205) indices = ext.data["random_spikes_indices"] assert indices.size == 10 * sorting_analyzer.sorting.unit_ids.size - _check_result_extension(sorting_analyzer, "random_spikes") + _check_result_extension(sorting_analyzer, "random_spikes", cache_folder) sorting_analyzer.delete_extension("random_spikes") ext = sorting_analyzer.compute("random_spikes", method="all") indices = ext.data["random_spikes_indices"] assert indices.size == len(sorting_analyzer.sorting.to_spike_vector()) - _check_result_extension(sorting_analyzer, "random_spikes") + _check_result_extension(sorting_analyzer, "random_spikes", cache_folder) @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) -def test_ComputeWaveforms(format, sparse): - sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) +def test_ComputeWaveforms(format, sparse, create_cache_folder): + cache_folder = create_cache_folder + sorting_analyzer = 
get_sorting_analyzer(cache_folder, format=format, sparse=sparse) job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205) ext = sorting_analyzer.compute("waveforms", **job_kwargs) wfs = ext.data["waveforms"] - _check_result_extension(sorting_analyzer, "waveforms") + _check_result_extension(sorting_analyzer, "waveforms", cache_folder) @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) -def test_ComputeTemplates(format, sparse): - sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) +def test_ComputeTemplates(format, sparse, create_cache_folder): + cache_folder = create_cache_folder + sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205) @@ -187,13 +184,14 @@ def test_ComputeTemplates(format, sparse): # ax.legend() # plt.show() - _check_result_extension(sorting_analyzer, "templates") + _check_result_extension(sorting_analyzer, "templates", cache_folder) @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) -def test_ComputeNoiseLevels(format, sparse): - sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) +def test_ComputeNoiseLevels(format, sparse, create_cache_folder): + cache_folder = create_cache_folder + sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse) sorting_analyzer.compute("noise_levels") print(sorting_analyzer) @@ -212,8 +210,9 @@ def test_get_children_dependencies(): assert rs_children.index("waveforms") < rs_children.index("templates") -def test_delete_on_recompute(): - sorting_analyzer = get_sorting_analyzer(format="memory", sparse=False) +def test_delete_on_recompute(create_cache_folder): + cache_folder = create_cache_folder + sorting_analyzer = 
get_sorting_analyzer(cache_folder, format="memory", sparse=False) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms") sorting_analyzer.compute("templates") @@ -224,8 +223,9 @@ def test_delete_on_recompute(): assert sorting_analyzer.get_extension("waveforms") is None -def test_compute_several(): - sorting_analyzer = get_sorting_analyzer(format="memory", sparse=False) +def test_compute_several(create_cache_folder): + cache_folder = create_cache_folder + sorting_analyzer = get_sorting_analyzer(cache_folder, format="memory", sparse=False) # should raise an error since waveforms depends on random_spikes, which isn't calculated with pytest.raises(AssertionError): diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 48dd11d996..eb6cf7ac12 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -18,14 +18,9 @@ from spikeinterface.core import generate_recording -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - cache_folder.mkdir(exist_ok=True, parents=True) - -def test_BaseRecording(): +def test_BaseRecording(create_cache_folder): + cache_folder = create_cache_folder num_seg = 2 num_chan = 3 num_samples = 30 diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index cb57b3861e..64f7f76819 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -14,14 +14,9 @@ from spikeinterface.core.npysnippetsextractor import NpySnippetsExtractor from spikeinterface.core.base import BaseExtractor -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - cache_folder.mkdir(exist_ok=True, parents=True) - 
-def test_BaseSnippets(): +def test_BaseSnippets(create_cache_folder): + cache_folder = create_cache_folder duration = [4, 3] num_channels = 3 nbefore = 20 diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 331c65b46a..42fdf52eb1 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -25,13 +25,9 @@ from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal from spikeinterface.core.generate import generate_sorting -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - -def test_BaseSorting(): +def test_BaseSorting(create_cache_folder): + cache_folder = create_cache_folder num_seg = 2 file_path = cache_folder / "test_BaseSorting.npz" file_path.parent.mkdir(exist_ok=True) diff --git a/src/spikeinterface/core/tests/test_binaryfolder.py b/src/spikeinterface/core/tests/test_binaryfolder.py index 4d38d1bc04..1e64afe4e4 100644 --- a/src/spikeinterface/core/tests/test_binaryfolder.py +++ b/src/spikeinterface/core/tests/test_binaryfolder.py @@ -9,13 +9,8 @@ from spikeinterface.core import generate_recording -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - - -def test_BinaryFolderRecording(): +def test_BinaryFolderRecording(create_cache_folder): + cache_folder = create_cache_folder rec = generate_recording(num_channels=10, durations=[2.0, 2.0]) folder = cache_folder / "binary_folder_1" diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index fb4c3ee3c4..61af8f322d 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -1,17 +1,12 @@ 
import pytest import numpy as np -from pathlib import Path from spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core.numpyextractors import NumpyRecording -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - -def test_BinaryRecordingExtractor(): +def test_BinaryRecordingExtractor(create_cache_folder): + cache_folder = create_cache_folder num_seg = 2 num_channels = 3 num_samples = 30 diff --git a/src/spikeinterface/core/tests/test_channelslicerecording.py b/src/spikeinterface/core/tests/test_channelslicerecording.py index e2e4dfdb2c..5d9354de9b 100644 --- a/src/spikeinterface/core/tests/test_channelslicerecording.py +++ b/src/spikeinterface/core/tests/test_channelslicerecording.py @@ -10,11 +10,8 @@ from spikeinterface.core.generate import generate_recording -def test_ChannelSliceRecording(): - if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" - else: - cache_folder = Path("cache_folder") / "core" +def test_ChannelSliceRecording(create_cache_folder): + cache_folder = create_cache_folder num_seg = 2 num_chan = 3 diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 9f0c83189f..8e00dcb779 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -15,12 +15,6 @@ ) -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - - def test_add_suffix(): # first case - no dot provided before extension file_path = "testpath" diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 668cdb980f..9677378fc5 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -14,13 +14,9 @@ ) from 
spikeinterface.core.job_tools import fix_job_kwargs -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - -def test_global_dataset_folder(): +def test_global_dataset_folder(create_cache_folder): + cache_folder = create_cache_folder dataset_folder = get_global_dataset_folder() assert dataset_folder.is_dir() new_dataset_folder = cache_folder / "dataset_folder" @@ -29,7 +25,8 @@ def test_global_dataset_folder(): assert new_dataset_folder.is_dir() -def test_global_tmp_folder(): +def test_global_tmp_folder(create_cache_folder): + cache_folder = create_cache_folder tmp_folder = get_global_tmp_folder() assert tmp_folder.is_dir() new_tmp_folder = cache_folder / "tmp_folder" diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 76ba5d041b..03acc9fed1 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -18,12 +18,6 @@ ) -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - - class AmplitudeExtractionNode(PipelineNode): def __init__(self, recording, parents=None, return_output=True, param0=5.5): PipelineNode.__init__(self, recording, parents=parents, return_output=return_output) @@ -68,7 +62,14 @@ def compute(self, traces, peaks, waveforms): return rms_by_channels -def test_run_node_pipeline(): +@pytest.fixture(scope="module") +def cache_folder_creation(tmp_path_factory): + cache_folder = tmp_path_factory.mktemp("cache_folder") + return cache_folder + + +def test_run_node_pipeline(cache_folder_creation): + cache_folder = cache_folder_creation recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) diff --git 
a/src/spikeinterface/core/tests/test_noise_levels_propagation.py b/src/spikeinterface/core/tests/test_noise_levels_propagation.py index 6f1b46bd33..421f709d06 100644 --- a/src/spikeinterface/core/tests/test_noise_levels_propagation.py +++ b/src/spikeinterface/core/tests/test_noise_levels_propagation.py @@ -7,13 +7,6 @@ import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" - -set_global_tmp_folder(cache_folder) - def test_skip_noise_levels_propagation(): rec = generate_recording(durations=[5.0], num_channels=4) diff --git a/src/spikeinterface/core/tests/test_npyfoldersnippets.py b/src/spikeinterface/core/tests/test_npyfoldersnippets.py index 12edef9d15..c0d7f303bf 100644 --- a/src/spikeinterface/core/tests/test_npyfoldersnippets.py +++ b/src/spikeinterface/core/tests/test_npyfoldersnippets.py @@ -7,13 +7,15 @@ from spikeinterface.core import generate_snippets -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" +@pytest.fixture(scope="module") +def cache_folder_creation(tmp_path_factory): + cache_folder = tmp_path_factory.mktemp("cache_folder") + return cache_folder -def test_NpyFolderSnippets(): +def test_NpyFolderSnippets(cache_folder_creation): + + cache_folder = cache_folder_creation snippets, _ = generate_snippets(num_channels=10, durations=[2.0, 1.0]) folder = cache_folder / "npy_folder_1" diff --git a/src/spikeinterface/core/tests/test_npysnippetsextractor.py b/src/spikeinterface/core/tests/test_npysnippetsextractor.py index 8a44f657fc..c3fbfcd885 100644 --- a/src/spikeinterface/core/tests/test_npysnippetsextractor.py +++ b/src/spikeinterface/core/tests/test_npysnippetsextractor.py @@ -4,13 +4,9 @@ from spikeinterface.core import NpySnippetsExtractor from spikeinterface.core import generate_snippets -if hasattr(pytest, 
"global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - -def test_NpySnippetsExtractor(): +def test_NpySnippetsExtractor(create_cache_folder): + cache_folder = create_cache_folder segment_durations = [2, 5] sampling_frequency = 30000 file_path = [cache_folder / f"test_NpySnippetsExtractor_{i}.npy" for i in range(len(segment_durations))] diff --git a/src/spikeinterface/core/tests/test_npzsortingextractor.py b/src/spikeinterface/core/tests/test_npzsortingextractor.py index a9c34d97df..4d84ee9a4d 100644 --- a/src/spikeinterface/core/tests/test_npzsortingextractor.py +++ b/src/spikeinterface/core/tests/test_npzsortingextractor.py @@ -4,13 +4,9 @@ from spikeinterface.core import NpzSortingExtractor from spikeinterface.core import create_sorting_npz -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - -def test_NpzSortingExtractor(): +def test_NpzSortingExtractor(create_cache_folder): + cache_folder = create_cache_folder num_seg = 2 file_path = cache_folder / "test_NpzSortingExtractor.npz" diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index c694026918..fecafb8989 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -1,6 +1,3 @@ -import shutil -from pathlib import Path - import pytest import numpy as np @@ -19,13 +16,9 @@ from spikeinterface.core.basesorting import minimum_spike_dtype -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - -def test_NumpyRecording(): +@pytest.fixture(scope="module") +def setup_NumpyRecording(tmp_path_factory): sampling_frequency = 30000 timeseries_list = [] for seg_index in range(3): @@ -36,8 +29,9 @@ def 
test_NumpyRecording(): # print(rec) times1 = rec.get_times(1) - + cache_folder = tmp_path_factory.mktemp("cache_folder") rec.save(folder=cache_folder / "test_NumpyRecording") + return cache_folder def test_SharedMemoryRecording(): @@ -57,7 +51,7 @@ def test_SharedMemoryRecording(): del rec -def test_NumpySorting(): +def test_NumpySorting(setup_NumpyRecording): sampling_frequency = 30000 # empty @@ -82,6 +76,9 @@ def test_NumpySorting(): # from other extracrtor num_seg = 2 + + cache_folder = setup_NumpyRecording + file_path = cache_folder / "test_NpzSortingExtractor.npz" create_sorting_npz(num_seg, file_path) other_sorting = NpzSortingExtractor(file_path) diff --git a/src/spikeinterface/core/tests/test_sorting_folder.py b/src/spikeinterface/core/tests/test_sorting_folder.py index 359e3ee7fc..0fe19534ec 100644 --- a/src/spikeinterface/core/tests/test_sorting_folder.py +++ b/src/spikeinterface/core/tests/test_sorting_folder.py @@ -1,6 +1,4 @@ import pytest - -from pathlib import Path import shutil import numpy as np @@ -9,13 +7,9 @@ from spikeinterface.core import generate_sorting from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - -def test_NumpyFolderSorting(): +def test_NumpyFolderSorting(create_cache_folder): + cache_folder = create_cache_folder sorting = generate_sorting(seed=42) folder = cache_folder / "numpy_sorting_1" @@ -33,7 +27,8 @@ def test_NumpyFolderSorting(): ) -def test_NpzFolderSorting(): +def test_NpzFolderSorting(create_cache_folder): + cache_folder = create_cache_folder sorting = generate_sorting(seed=42) folder = cache_folder / "npz_folder_sorting_1" diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 8bb8778c76..487a893096 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ 
b/src/spikeinterface/core/tests/test_time_handling.py @@ -1,17 +1,11 @@ import pytest -from pathlib import Path import numpy as np from spikeinterface.core import generate_recording, generate_sorting -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - - -def test_time_handling(): +def test_time_handling(create_cache_folder): + cache_folder = create_cache_folder durations = [[10], [10, 5]] # test multi-segment diff --git a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py index 814e75af3d..b6cb479c7d 100644 --- a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py +++ b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py @@ -1,6 +1,5 @@ import pytest import numpy as np -from pathlib import Path from spikeinterface.core import aggregate_units @@ -8,13 +7,9 @@ from spikeinterface.core import create_sorting_npz -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" +def test_unitsaggregationsorting(create_cache_folder): + cache_folder = create_cache_folder - -def test_unitsaggregationsorting(): num_seg = 2 file_path = cache_folder / "test_BaseSorting.npz" diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 25396841cc..845eaf1310 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -15,12 +15,6 @@ ) -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - - def _check_all_wf_equal(list_wfs_arrays): wfs_arrays0 = list_wfs_arrays[0] for i, wfs_arrays in enumerate(list_wfs_arrays): @@ -41,7 +35,8 @@ def get_dataset(): return recording, sorting -def 
test_waveform_tools(): +def test_waveform_tools(create_cache_folder): + cache_folder = create_cache_folder # durations = [30, 40] # sampling_frequency = 30000.0 diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py index 9688dc7825..0157965daf 100644 --- a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -16,12 +16,6 @@ from spikeinterface.core import extract_waveforms as old_extract_waveforms -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - - def get_dataset(): recording, sorting = generate_ground_truth_recording( durations=[30.0, 20.0], @@ -45,7 +39,8 @@ def get_dataset(): return recording, sorting -def test_extract_waveforms(): +def test_extract_waveforms(create_cache_folder): + cache_folder = create_cache_folder recording, sorting = get_dataset() folder = cache_folder / "old_waveforms_extractor" diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index 5a288a35c8..9cd20f4bfc 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -1,17 +1,10 @@ from __future__ import annotations import pytest -from pathlib import Path from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer from spikeinterface.qualitymetrics import compute_quality_metrics -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "curation" -else: - cache_folder = Path("cache_folder") / "curation" - - job_kwargs = dict(n_jobs=-1) diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index bbf861dac9..93c302f1f6 100644 --- 
a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -12,12 +12,6 @@ from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "curation" -else: - cache_folder = Path("cache_folder") / "curation" - - def test_get_auto_merge_list(sorting_analyzer_for_curation): sorting = sorting_analyzer_for_curation.sorting diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 8f9e3e570c..00721ff34d 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -16,19 +16,11 @@ ) from spikeinterface.curation import apply_sortingview_curation -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "curation" -else: - cache_folder = Path("cache_folder") / "curation" - parent_folder = Path(__file__).parent ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) -set_global_tmp_folder(cache_folder) - - # this needs to be run only once: if we want to regenerate we need to start with sorting result # TODO : regenerate the # def generate_sortingview_curation_dataset(): diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index a6fc2abf99..78a9c82860 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -5,11 +5,6 @@ from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer, compute_sparsity -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "exporters" -else: - cache_folder = Path("cache_folder") / "exporters" - def 
make_sorting_analyzer(sparse=True, with_group=False): recording, sorting = generate_ground_truth_recording( diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index 18ba15b975..47294b3cf7 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -1,25 +1,20 @@ -import pytest import shutil -from pathlib import Path import numpy as np -from spikeinterface.postprocessing import compute_principal_components - -from spikeinterface.core import compute_sparsity from spikeinterface.exporters import export_to_phy from spikeinterface.exporters.tests.common import ( - cache_folder, make_sorting_analyzer, + sorting_analyzer_dense_for_export, sorting_analyzer_sparse_for_export, sorting_analyzer_with_group_for_export, - sorting_analyzer_dense_for_export, ) -def test_export_to_phy_dense(sorting_analyzer_dense_for_export): +def test_export_to_phy_dense(sorting_analyzer_dense_for_export, create_cache_folder): + cache_folder = create_cache_folder output_folder1 = cache_folder / "phy_output_dense" for f in (output_folder1,): if f.is_dir(): @@ -38,7 +33,8 @@ def test_export_to_phy_dense(sorting_analyzer_dense_for_export): ) -def test_export_to_phy_sparse(sorting_analyzer_sparse_for_export): +def test_export_to_phy_sparse(sorting_analyzer_sparse_for_export, create_cache_folder): + cache_folder = create_cache_folder output_folder1 = cache_folder / "phy_output_1" output_folder2 = cache_folder / "phy_output_2" for f in (output_folder1, output_folder2): @@ -70,7 +66,8 @@ def test_export_to_phy_sparse(sorting_analyzer_sparse_for_export): ) -def test_export_to_phy_by_property(sorting_analyzer_with_group_for_export): +def test_export_to_phy_by_property(sorting_analyzer_with_group_for_export, create_cache_folder): + cache_folder = create_cache_folder output_folder = cache_folder / "phy_output_property" for f in (output_folder,): diff --git 
a/src/spikeinterface/exporters/tests/test_report.py b/src/spikeinterface/exporters/tests/test_report.py index cd000bc077..c712fcafb1 100644 --- a/src/spikeinterface/exporters/tests/test_report.py +++ b/src/spikeinterface/exporters/tests/test_report.py @@ -1,18 +1,17 @@ -from pathlib import Path import shutil -import pytest - from spikeinterface.exporters import export_report from spikeinterface.exporters.tests.common import ( - cache_folder, make_sorting_analyzer, + sorting_analyzer_dense_for_export, sorting_analyzer_sparse_for_export, + sorting_analyzer_with_group_for_export, ) -def test_export_report(sorting_analyzer_sparse_for_export): +def test_export_report(sorting_analyzer_sparse_for_export, create_cache_folder): + cache_folder = create_cache_folder report_folder = cache_folder / "report" if report_folder.exists(): shutil.rmtree(report_folder) diff --git a/src/spikeinterface/extractors/tests/test_mdaextractors.py b/src/spikeinterface/extractors/tests/test_mdaextractors.py index 6440e575d5..0ef6697c6c 100644 --- a/src/spikeinterface/extractors/tests/test_mdaextractors.py +++ b/src/spikeinterface/extractors/tests/test_mdaextractors.py @@ -4,13 +4,9 @@ from spikeinterface.core import generate_ground_truth_recording from spikeinterface.extractors import MdaRecordingExtractor, MdaSortingExtractor -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "extractors" -else: - cache_folder = Path("cache_folder") / "extractors" - -def test_mda_extractors(): +def test_mda_extractors(create_cache_folder): + cache_folder = create_cache_folder rec, sort = generate_ground_truth_recording(durations=[10.0], num_units=10) MdaRecordingExtractor.write_recording(rec, cache_folder / "mdatest") diff --git a/src/spikeinterface/extractors/tests/test_shybridextractors.py b/src/spikeinterface/extractors/tests/test_shybridextractors.py index a0164fd119..221e1bfc2d 100644 --- a/src/spikeinterface/extractors/tests/test_shybridextractors.py +++ 
b/src/spikeinterface/extractors/tests/test_shybridextractors.py @@ -1,17 +1,13 @@ import pytest -from pathlib import Path + from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal from spikeinterface.extractors import SHYBRIDRecordingExtractor, SHYBRIDSortingExtractor -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "extractors" -else: - cache_folder = Path("cache_folder") / "extractors" - @pytest.mark.skipif(True, reason="SHYBRID only tested locally") -def test_shybrid_extractors(): +def test_shybrid_extractors(create_cache_folder): + cache_folder = create_cache_folder rec, sort = generate_ground_truth_recording(durations=[10.0], num_units=10) SHYBRIDSortingExtractor.write_sorting(sort, cache_folder / "shybridtest") diff --git a/src/spikeinterface/generation/tests/test_drift_tools.py b/src/spikeinterface/generation/tests/test_drift_tools.py index e64e64ffda..8a4837100e 100644 --- a/src/spikeinterface/generation/tests/test_drift_tools.py +++ b/src/spikeinterface/generation/tests/test_drift_tools.py @@ -1,6 +1,4 @@ -import pytest import numpy as np -from pathlib import Path import shutil from spikeinterface.generation import ( @@ -16,12 +14,6 @@ from probeinterface import generate_multi_columns_probe -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "generation" -else: - cache_folder = Path("cache_folder") / "generation" - - def make_some_templates(): probe = generate_multi_columns_probe( num_columns=12, @@ -121,7 +113,8 @@ def test_DriftingTemplates(): ) -def test_InjectDriftingTemplatesRecording(): +def test_InjectDriftingTemplatesRecording(create_cache_folder): + cache_folder = create_cache_folder templates = make_some_templates() probe = templates.probe diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py 
b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 605997f5f6..281782745a 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -3,21 +3,12 @@ import pytest import numpy as np import shutil -from pathlib import Path from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer from spikeinterface.core import estimate_sparsity -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "postprocessing" -else: - cache_folder = Path("cache_folder") / "postprocessing" - -cache_folder.mkdir(exist_ok=True, parents=True) - - def get_dataset(): recording, sorting = generate_ground_truth_recording( durations=[15.0, 5.0], @@ -41,24 +32,6 @@ def get_dataset(): return recording, sorting -def get_sorting_analyzer(recording, sorting, format="memory", sparsity=None, name=""): - sparse = sparsity is not None - if format == "memory": - folder = None - elif format == "binary_folder": - folder = cache_folder / f"test_{name}_sparse{sparse}_{format}" - elif format == "zarr": - folder = cache_folder / f"test_{name}_sparse{sparse}_{format}.zarr" - if folder and folder.exists(): - shutil.rmtree(folder) - - sorting_analyzer = create_sorting_analyzer( - sorting, recording, format=format, folder=folder, sparse=False, sparsity=sparsity - ) - - return sorting_analyzer - - class AnalyzerExtensionCommonTestSuite: """ Common tests with class approach to compute extension on several cases (3 format x 2 sparsity) @@ -83,10 +56,31 @@ def setUpClass(cls): def extension_name(self): return self.extension_class.extension_name + @pytest.fixture(autouse=True) + def create_cache_folder(self, tmp_path_factory): + self.cache_folder = tmp_path_factory.mktemp("cache_folder") + + def get_sorting_analyzer(self, recording, sorting, format="memory", sparsity=None, name=""): + sparse = sparsity is not None + 
if format == "memory": + folder = None + elif format == "binary_folder": + folder = self.cache_folder / f"test_{name}_sparse{sparse}_{format}" + elif format == "zarr": + folder = self.cache_folder / f"test_{name}_sparse{sparse}_{format}.zarr" + if folder and folder.exists(): + shutil.rmtree(folder) + + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format=format, folder=folder, sparse=False, sparsity=sparsity + ) + + return sorting_analyzer + def _prepare_sorting_analyzer(self, format, sparse): # prepare a SortingAnalyzer object with depencies already computed sparsity_ = self.sparsity if sparse else None - sorting_analyzer = get_sorting_analyzer( + sorting_analyzer = self.get_sorting_analyzer( self.recording, self.sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name ) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205) diff --git a/src/spikeinterface/postprocessing/tests/test_align_sorting.py b/src/spikeinterface/postprocessing/tests/test_align_sorting.py index e5c70ae4b2..a02e224984 100644 --- a/src/spikeinterface/postprocessing/tests/test_align_sorting.py +++ b/src/spikeinterface/postprocessing/tests/test_align_sorting.py @@ -1,6 +1,5 @@ import pytest import shutil -from pathlib import Path import pytest @@ -11,11 +10,6 @@ from spikeinterface.postprocessing import align_sorting -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "postprocessing" -else: - cache_folder = Path("cache_folder") / "postprocessing" - def test_align_sorting(): sorting = generate_sorting(durations=[10.0], seed=0) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index d94d7ea586..f9b847ec22 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -5,7 +5,7 @@ import numpy as np 
from spikeinterface.postprocessing import ComputePrincipalComponents, compute_principal_components -from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite, cache_folder +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite DEBUG = False @@ -100,12 +100,12 @@ def test_compute_for_all_spikes(self): sorting_analyzer.compute("principal_components", mode="by_channel_local", n_components=n_components) ext = sorting_analyzer.get_extension("principal_components") - pc_file1 = cache_folder / "all_pc1.npy" + pc_file1 = self.cache_folder / "all_pc1.npy" ext.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) all_pc1 = np.load(pc_file1) assert all_pc1.shape[0] == num_spikes - pc_file2 = cache_folder / "all_pc2.npy" + pc_file2 = self.cache_folder / "all_pc2.npy" ext.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) all_pc2 = np.load(pc_file2) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 534c909592..1693530454 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -2,7 +2,6 @@ from spikeinterface.postprocessing.tests.common_extension_tests import ( AnalyzerExtensionCommonTestSuite, - get_sorting_analyzer, get_dataset, ) @@ -15,30 +14,28 @@ class SimilarityExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCas dict(method="cosine_similarity"), ] + def test_check_equal_template_with_distribution_overlap(self): + recording, sorting = get_dataset() -def test_check_equal_template_with_distribution_overlap(): + sorting_analyzer = self.get_sorting_analyzer(recording=recording, sorting=sorting, sparsity=None) + sorting_analyzer.compute("random_spikes") + sorting_analyzer.compute("waveforms") + sorting_analyzer.compute("templates") - recording, sorting = 
get_dataset() + wf_ext = sorting_analyzer.get_extension("waveforms") - sorting_analyzer = get_sorting_analyzer(recording, sorting, sparsity=None) - sorting_analyzer.compute("random_spikes") - sorting_analyzer.compute("waveforms") - sorting_analyzer.compute("templates") - - wf_ext = sorting_analyzer.get_extension("waveforms") - - for unit_id0 in sorting_analyzer.unit_ids: - waveforms0 = wf_ext.get_waveforms_one_unit(unit_id0) - for unit_id1 in sorting_analyzer.unit_ids: - if unit_id0 == unit_id1: - continue - waveforms1 = wf_ext.get_waveforms_one_unit(unit_id1) - check_equal_template_with_distribution_overlap(waveforms0, waveforms1) + for unit_id0 in sorting_analyzer.unit_ids: + waveforms0 = wf_ext.get_waveforms_one_unit(unit_id0) + for unit_id1 in sorting_analyzer.unit_ids: + if unit_id0 == unit_id1: + continue + waveforms1 = wf_ext.get_waveforms_one_unit(unit_id1) + check_equal_template_with_distribution_overlap(waveforms0, waveforms1) if __name__ == "__main__": - # test = SimilarityExtensionTest() + test = SimilarityExtensionTest() # test.setUpClass() # test.test_extension() - test_check_equal_template_with_distribution_overlap() + # test_check_equal_template_with_distribution_overlap() diff --git a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py index 7067449944..46c0fbb29c 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py @@ -25,12 +25,6 @@ HAVE_DEEPINTERPOLATION = False -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "deepinterpolation" -else: - cache_folder = Path("cache_folder") / "deepinterpolation" - - def recording_and_shape(): num_cols = 2 num_rows = 64 @@ -44,7 +38,7 @@ def recording_and_shape(): return recording, desired_shape -@pytest.fixture 
+@pytest.fixture(scope="module") def recording_and_shape_fixture(): return recording_and_shape() @@ -73,9 +67,10 @@ def test_deepinterpolation_generator_borders(recording_and_shape_fixture): @pytest.mark.skipif(not HAVE_DEEPINTERPOLATION, reason="requires deepinterpolation") @pytest.mark.dependency() -def test_deepinterpolation_training(recording_and_shape_fixture): +def test_deepinterpolation_training(recording_and_shape_fixture, create_cache_folder): recording, desired_shape = recording_and_shape_fixture + cache_folder = create_cache_folder model_folder = Path(cache_folder) / "training" # train model_path = train_deepinterpolation( @@ -93,15 +88,15 @@ def test_deepinterpolation_training(recording_and_shape_fixture): run_uid="si_test", pre_post_omission=1, desired_shape=desired_shape, - nb_workers=1, ) print(model_path) @pytest.mark.skipif(not HAVE_DEEPINTERPOLATION, reason="requires deepinterpolation") @pytest.mark.dependency(depends=["test_deepinterpolation_training"]) -def test_deepinterpolation_transfer(recording_and_shape_fixture, tmp_path): +def test_deepinterpolation_transfer(recording_and_shape_fixture, tmp_path, create_cache_folder): recording, desired_shape = recording_and_shape_fixture + cache_folder = create_cache_folder existing_model_path = Path(cache_folder) / "training" / "si_test_training_model.h5" model_folder = Path(tmp_path) / "transfer" @@ -128,9 +123,10 @@ def test_deepinterpolation_transfer(recording_and_shape_fixture, tmp_path): @pytest.mark.skipif(not HAVE_DEEPINTERPOLATION, reason="requires deepinterpolation") @pytest.mark.dependency(depends=["test_deepinterpolation_training"]) -def test_deepinterpolation_inference(recording_and_shape_fixture): - recording, desired_shape = recording_and_shape_fixture +def test_deepinterpolation_inference(recording_and_shape_fixture, create_cache_folder): + recording, _ = recording_and_shape_fixture pre_frame = post_frame = 20 + cache_folder = create_cache_folder existing_model_path = Path(cache_folder) / 
"training" / "si_test_training_model.h5" recording_di = deepinterpolate( @@ -154,9 +150,10 @@ def test_deepinterpolation_inference(recording_and_shape_fixture): @pytest.mark.skipif(not HAVE_DEEPINTERPOLATION, reason="requires deepinterpolation") @pytest.mark.dependency(depends=["test_deepinterpolation_training"]) -def test_deepinterpolation_inference_multi_job(recording_and_shape_fixture): - recording, desired_shape = recording_and_shape_fixture +def test_deepinterpolation_inference_multi_job(recording_and_shape_fixture, create_cache_folder): + recording, _ = recording_and_shape_fixture pre_frame = post_frame = 20 + cache_folder = create_cache_folder existing_model_path = Path(cache_folder) / "training" / "si_test_training_model.h5" recording_di = deepinterpolate( diff --git a/src/spikeinterface/preprocessing/tests/test_align_snippets.py b/src/spikeinterface/preprocessing/tests/test_align_snippets.py index 488c9adeb9..104c911278 100644 --- a/src/spikeinterface/preprocessing/tests/test_align_snippets.py +++ b/src/spikeinterface/preprocessing/tests/test_align_snippets.py @@ -3,19 +3,12 @@ but check only for BaseRecording general methods. 
""" -from pathlib import Path import pytest import numpy as np from spikeinterface.core import generate_snippets from spikeinterface.preprocessing.align_snippets import AlignSnippets -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - cache_folder.mkdir(exist_ok=True, parents=True) - def test_AlignSnippets(): duration = [4, 3] diff --git a/src/spikeinterface/preprocessing/tests/test_average_across_direction.py b/src/spikeinterface/preprocessing/tests/test_average_across_direction.py index 9543a669bc..dc3edc3b1d 100644 --- a/src/spikeinterface/preprocessing/tests/test_average_across_direction.py +++ b/src/spikeinterface/preprocessing/tests/test_average_across_direction.py @@ -8,13 +8,6 @@ import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" - -set_global_tmp_folder(cache_folder) - def test_average_across_direction(): # gradient recording with 100 samples and 10 channels diff --git a/src/spikeinterface/preprocessing/tests/test_clip.py b/src/spikeinterface/preprocessing/tests/test_clip.py index 990730cc36..724ba2c963 100644 --- a/src/spikeinterface/preprocessing/tests/test_clip.py +++ b/src/spikeinterface/preprocessing/tests/test_clip.py @@ -1,8 +1,3 @@ -import pytest -from pathlib import Path -import shutil - -from spikeinterface import set_global_tmp_folder from spikeinterface.core import generate_recording from spikeinterface.preprocessing import clip, blank_staturation @@ -10,14 +5,6 @@ import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" - -set_global_tmp_folder(cache_folder) - - def test_clip(): rec = generate_recording() diff --git a/src/spikeinterface/preprocessing/tests/test_depth_order.py 
b/src/spikeinterface/preprocessing/tests/test_depth_order.py index bc959b8ddb..b0dbc2a8da 100644 --- a/src/spikeinterface/preprocessing/tests/test_depth_order.py +++ b/src/spikeinterface/preprocessing/tests/test_depth_order.py @@ -1,20 +1,11 @@ import pytest -from pathlib import Path -from spikeinterface import set_global_tmp_folder from spikeinterface.core import NumpyRecording from spikeinterface.preprocessing import DepthOrderRecording, depth_order import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" - -set_global_tmp_folder(cache_folder) - def test_depth_order(): # gradient recording with 100 samples and 10 channels diff --git a/src/spikeinterface/preprocessing/tests/test_directional_derivative.py b/src/spikeinterface/preprocessing/tests/test_directional_derivative.py index d863ea9c59..5f887ae35c 100644 --- a/src/spikeinterface/preprocessing/tests/test_directional_derivative.py +++ b/src/spikeinterface/preprocessing/tests/test_directional_derivative.py @@ -1,20 +1,9 @@ -import pytest -from pathlib import Path - -from spikeinterface import set_global_tmp_folder from spikeinterface.core import NumpyRecording from spikeinterface.preprocessing import DirectionalDerivativeRecording, directional_derivative import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" - -set_global_tmp_folder(cache_folder) - def test_directional_derivative(): # gradient recording with 100 samples and 10 channels diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index fc28463dff..68790b3273 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -1,5 +1,4 @@ import pytest -from pathlib 
import Path import numpy as np from spikeinterface.core import generate_recording @@ -7,13 +6,6 @@ from spikeinterface.preprocessing import filter, bandpass_filter, notch_filter -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" - -set_global_tmp_folder(cache_folder) - def test_filter(): rec = generate_recording() diff --git a/src/spikeinterface/preprocessing/tests/test_filter_gaussian.py b/src/spikeinterface/preprocessing/tests/test_filter_gaussian.py index 10fdc5e8d4..54682f2e94 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter_gaussian.py +++ b/src/spikeinterface/preprocessing/tests/test_filter_gaussian.py @@ -9,19 +9,10 @@ from spikeinterface.core import NumpyRecording -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" / "gaussian_bandpass_filter" -else: - cache_folder = Path("cache_folder") / "preprocessing" / "gaussian_bandpass_filter" - -set_global_tmp_folder(cache_folder) -cache_folder.mkdir(parents=True, exist_ok=True) - - -def test_filter_gaussian(): +def test_filter_gaussian(tmp_path): recording = generate_recording(num_channels=3) recording.annotate(is_filtered=True) - recording = recording.save(folder=cache_folder / "recording") + recording = recording.save(folder=tmp_path / "recording") rec_filtered = gaussian_filter(recording) @@ -35,8 +26,8 @@ def test_filter_gaussian(): saved_loaded = load_extractor(rec_filtered.to_dict()) check_recordings_equal(rec_filtered, saved_loaded, return_scaled=False) - saved_1job = rec_filtered.save(folder=cache_folder / "1job") - saved_2job = rec_filtered.save(folder=cache_folder / "2job", n_jobs=2, chunk_duration="1s") + saved_1job = rec_filtered.save(folder=tmp_path / "1job") + saved_2job = rec_filtered.save(folder=tmp_path / "2job", n_jobs=2, chunk_duration="1s") for seg_idx in range(rec_filtered.get_num_segments()): original_trace = 
rec_filtered.get_traces(seg_idx) diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index a7f3fe1efa..e79fda1ad8 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -9,18 +9,10 @@ import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" - -print(cache_folder.absolute()) - - -def test_estimate_and_correct_motion(): +def test_estimate_and_correct_motion(create_cache_folder): + cache_folder = create_cache_folder rec = generate_recording(durations=[30.0], num_channels=12) - print(rec) folder = cache_folder / "estimate_and_correct_motion" diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index 69e45425c1..576b570832 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -9,14 +9,6 @@ import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" - -set_global_tmp_folder(cache_folder) - - def test_normalize_by_quantile(): rec = generate_recording() diff --git a/src/spikeinterface/preprocessing/tests/test_rectify.py b/src/spikeinterface/preprocessing/tests/test_rectify.py index cca41ebf7d..b8bb31015e 100644 --- a/src/spikeinterface/preprocessing/tests/test_rectify.py +++ b/src/spikeinterface/preprocessing/tests/test_rectify.py @@ -9,14 +9,6 @@ import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" - -set_global_tmp_folder(cache_folder) - - def test_rectify(): rec = 
generate_recording() diff --git a/src/spikeinterface/preprocessing/tests/test_remove_artifacts.py b/src/spikeinterface/preprocessing/tests/test_remove_artifacts.py index b8a6e83f67..3461bf9b1b 100644 --- a/src/spikeinterface/preprocessing/tests/test_remove_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_remove_artifacts.py @@ -1,18 +1,10 @@ import pytest -from pathlib import Path + import numpy as np -from spikeinterface import set_global_tmp_folder from spikeinterface.core import generate_recording from spikeinterface.preprocessing import remove_artifacts -# if hasattr(pytest, "global_test_folder"): -# cache_folder = pytest.global_test_folder / "preprocessing" -# else: -# cache_folder = Path("cache_folder") / "preprocessing" - -# set_global_tmp_folder(cache_folder) - def test_remove_artifacts(): # one segment only diff --git a/src/spikeinterface/preprocessing/tests/test_resample.py b/src/spikeinterface/preprocessing/tests/test_resample.py index 2fa76ffe08..df17feaaf4 100644 --- a/src/spikeinterface/preprocessing/tests/test_resample.py +++ b/src/spikeinterface/preprocessing/tests/test_resample.py @@ -1,18 +1,9 @@ -import pytest -from pathlib import Path - - from spikeinterface.preprocessing import resample from spikeinterface.core import NumpyRecording import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" - DEBUG = False # DEBUG = True diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence.py index a362f4dfbd..ed11eb8fdd 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence.py @@ -1,8 +1,5 @@ import pytest -from pathlib import Path -import shutil -from spikeinterface import set_global_tmp_folder from spikeinterface.core import generate_recording from spikeinterface.preprocessing 
import silence_periods @@ -13,15 +10,10 @@ import numpy as np -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" +def test_silence(create_cache_folder): -set_global_tmp_folder(cache_folder) + cache_folder = create_cache_folder - -def test_silence(): rec = generate_recording() rec0 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="zeros") diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 40674a08f4..c3d1544869 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -1,21 +1,13 @@ import pytest import numpy as np -from pathlib import Path -from spikeinterface import set_global_tmp_folder from spikeinterface.core import generate_recording from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" -set_global_tmp_folder(cache_folder) - - -def test_whiten(): +def test_whiten(create_cache_folder): + cache_folder = create_cache_folder rec = generate_recording(num_channels=4) print(rec.get_channel_locations()) diff --git a/src/spikeinterface/preprocessing/tests/test_zero_padding.py b/src/spikeinterface/preprocessing/tests/test_zero_padding.py index 3ece8a0e0d..dfcc7b661b 100644 --- a/src/spikeinterface/preprocessing/tests/test_zero_padding.py +++ b/src/spikeinterface/preprocessing/tests/test_zero_padding.py @@ -9,13 +9,6 @@ from spikeinterface.preprocessing import zero_channel_pad, bandpass_filter, phase_shift from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - 
cache_folder = Path("cache_folder") / "preprocessing" - -set_global_tmp_folder(cache_folder) - def test_zero_padding_channel(): num_original_channels = 4 diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 5a7d43cbae..2d4eeb360b 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -38,11 +38,6 @@ from spikeinterface.core.basesorting import minimum_spike_dtype -# if hasattr(pytest, "global_test_folder"): -# cache_folder = pytest.global_test_folder / "qualitymetrics" -# else: -# cache_folder = Path("cache_folder") / "qualitymetrics" - job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 3ae879e3f2..28869ba5ff 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -16,12 +16,6 @@ ) -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "qualitymetrics" -else: - cache_folder = Path("cache_folder") / "qualitymetrics" - - job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index cdd5bc1abb..2f87065d9f 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -102,7 +102,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo # installed ? if not cls.is_installed(): raise Exception( - f"The sorter {cls.sorter_name} is not installed." f"Please install it with: \n{cls.installation_mesg} " + f"The sorter {cls.sorter_name} is not installed. 
Please install it with:\n{cls.installation_mesg}" ) if not isinstance(recording, BaseRecordingSnippets): diff --git a/src/spikeinterface/sorters/external/tests/test_docker_containers.py b/src/spikeinterface/sorters/external/tests/test_docker_containers.py index 42d7e48a2e..f5c42eb6d1 100644 --- a/src/spikeinterface/sorters/external/tests/test_docker_containers.py +++ b/src/spikeinterface/sorters/external/tests/test_docker_containers.py @@ -1,19 +1,12 @@ import os -import shutil import pytest -from pathlib import Path from spikeinterface import generate_ground_truth_recording from spikeinterface.core.core_tools import is_editable_mode import spikeinterface.extractors as se import spikeinterface.sorters as ss -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "sorters" -else: - cache_folder = Path("cache_folder") / "sorters" - ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -40,66 +33,77 @@ def run_kwargs(): return generate_run_kwargs() -def test_spykingcircus(run_kwargs): - sorting = ss.run_sorter("spykingcircus", output_folder=cache_folder / "spykingcircus", **run_kwargs) +def test_spykingcircus(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder + print(cache_folder) + sorting = ss.run_sorter("spykingcircus", folder=cache_folder / "spykingcircus", **run_kwargs) print("resulting sorting") print(sorting) -def test_mountainsort4(run_kwargs): - sorting = ss.run_sorter("mountainsort4", output_folder=cache_folder / "mountainsort4", **run_kwargs) +def test_mountainsort4(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder + sorting = ss.run_sorter("mountainsort4", folder=cache_folder / "mountainsort4", **run_kwargs) print("resulting sorting") print(sorting) -def test_mountainsort5(run_kwargs): - sorting = ss.run_sorter("mountainsort5", output_folder=cache_folder / "mountainsort5", **run_kwargs) +def test_mountainsort5(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder + sorting 
= ss.run_sorter("mountainsort5", folder=cache_folder / "mountainsort5", **run_kwargs) print("resulting sorting") print(sorting) -def test_tridesclous(run_kwargs): - sorting = ss.run_sorter("tridesclous", output_folder=cache_folder / "tridesclous", **run_kwargs) +def test_tridesclous(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder + sorting = ss.run_sorter("tridesclous", folder=cache_folder / "tridesclous", **run_kwargs) print("resulting sorting") print(sorting) -def test_ironclust(run_kwargs): - sorting = ss.run_sorter("ironclust", output_folder=cache_folder / "ironclust", fGpu=False, **run_kwargs) +def test_ironclust(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder + sorting = ss.run_sorter("ironclust", folder=cache_folder / "ironclust", fGpu=False, **run_kwargs) print("resulting sorting") print(sorting) -def test_waveclus(run_kwargs): - sorting = ss.run_sorter(sorter_name="waveclus", output_folder=cache_folder / "waveclus", **run_kwargs) +def test_waveclus(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder + sorting = ss.run_sorter(sorter_name="waveclus", folder=cache_folder / "waveclus", **run_kwargs) print("resulting sorting") print(sorting) -def test_hdsort(run_kwargs): - sorting = ss.run_sorter(sorter_name="hdsort", output_folder=cache_folder / "hdsort", **run_kwargs) +def test_hdsort(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder + sorting = ss.run_sorter(sorter_name="hdsort", folder=cache_folder / "hdsort", **run_kwargs) print("resulting sorting") print(sorting) -def test_kilosort1(run_kwargs): - sorting = ss.run_sorter(sorter_name="kilosort", output_folder=cache_folder / "kilosort", useGPU=False, **run_kwargs) +def test_kilosort1(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder + sorting = ss.run_sorter(sorter_name="kilosort", folder=cache_folder / "kilosort", useGPU=False, **run_kwargs) print("resulting sorting") print(sorting) -def 
test_combinato(run_kwargs): +def test_combinato(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder rec = run_kwargs["recording"] channels = rec.get_channel_ids()[0:1] rec_one_channel = rec.channel_slice(channels) run_kwargs["recording"] = rec_one_channel - sorting = ss.run_sorter(sorter_name="combinato", output_folder=cache_folder / "combinato", **run_kwargs) + sorting = ss.run_sorter(sorter_name="combinato", folder=cache_folder / "combinato", **run_kwargs) print(sorting) @pytest.mark.skip("Klusta is not supported anymore for Python>=3.8") -def test_klusta(run_kwargs): - sorting = ss.run_sorter("klusta", output_folder=cache_folder / "klusta", **run_kwargs) +def test_klusta(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder + sorting = ss.run_sorter("klusta", folder=cache_folder / "klusta", **run_kwargs) print(sorting) diff --git a/src/spikeinterface/sorters/external/tests/test_kilosort4.py b/src/spikeinterface/sorters/external/tests/test_kilosort4.py index 62a1476407..dbaf3ffc5e 100644 --- a/src/spikeinterface/sorters/external/tests/test_kilosort4.py +++ b/src/spikeinterface/sorters/external/tests/test_kilosort4.py @@ -6,11 +6,6 @@ from spikeinterface.sorters import Kilosort4Sorter, run_sorter from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "sorters" -else: - cache_folder = Path("cache_folder") / "sorters" - # This run several tests @pytest.mark.skipif(not Kilosort4Sorter.is_installed(), reason="kilosort4 not installed") @@ -19,11 +14,11 @@ class Kilosort4SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): # 4 channels is to few for KS4 def setUp(self): - if (cache_folder / "rec").is_dir(): - recording = load_extractor(cache_folder / "rec") + if (self.cache_folder / "rec").is_dir(): + recording = load_extractor(self.cache_folder / "rec") else: recording, _ = 
generate_ground_truth_recording(num_channels=32, durations=[60], seed=0) - recording = recording.save(folder=cache_folder / "rec", verbose=False, format="binary") + recording = recording.save(folder=self.cache_folder / "rec", verbose=False, format="binary") self.recording = recording print(self.recording) @@ -32,7 +27,7 @@ def test_with_run_skip_correction(self): sorter_name = self.SorterClass.sorter_name - output_folder = cache_folder / sorter_name + output_folder = self.cache_folder / sorter_name sorter_params = self.SorterClass.default_params() sorter_params["do_correction"] = False @@ -66,7 +61,7 @@ def test_with_run_skip_preprocessing(self): sorter_name = self.SorterClass.sorter_name - output_folder = cache_folder / sorter_name + output_folder = self.cache_folder / sorter_name sorter_params = self.SorterClass.default_params() sorter_params["skip_kilosort_preprocessing"] = True @@ -101,7 +96,7 @@ def test_with_run_skip_preprocessing_and_correction(self): sorter_name = self.SorterClass.sorter_name - output_folder = cache_folder / sorter_name + output_folder = self.cache_folder / sorter_name sorter_params = self.SorterClass.default_params() sorter_params["skip_kilosort_preprocessing"] = True diff --git a/src/spikeinterface/sorters/external/tests/test_singularity_containers.py b/src/spikeinterface/sorters/external/tests/test_singularity_containers.py index afebb91bc2..61b928b6f7 100644 --- a/src/spikeinterface/sorters/external/tests/test_singularity_containers.py +++ b/src/spikeinterface/sorters/external/tests/test_singularity_containers.py @@ -1,18 +1,11 @@ import os import pytest -from pathlib import Path from spikeinterface import generate_ground_truth_recording from spikeinterface.core.core_tools import is_editable_mode -import spikeinterface.extractors as se import spikeinterface.sorters as ss -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "sorters" -else: - cache_folder = Path("cache_folder") / "sorters" - 
os.environ["SINGULARITY_DISABLE_CACHE"] = "true" @@ -46,76 +39,86 @@ def run_kwargs(): return generate_run_kwargs() -def test_spykingcircus(run_kwargs): +def test_spykingcircus(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder clean_singularity_cache() - sorting = ss.run_sorter("spykingcircus", output_folder=cache_folder / "spykingcircus", **run_kwargs) + sorting = ss.run_sorter("spykingcircus", folder=cache_folder / "spykingcircus", **run_kwargs) print("resulting sorting") print(sorting) -def test_mountainsort4(run_kwargs): +def test_mountainsort4(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder clean_singularity_cache() - sorting = ss.run_sorter("mountainsort4", output_folder=cache_folder / "mountainsort4", **run_kwargs) + sorting = ss.run_sorter("mountainsort4", folder=cache_folder / "mountainsort4", **run_kwargs) print("resulting sorting") print(sorting) -def test_mountainsort5(run_kwargs): +def test_mountainsort5(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder clean_singularity_cache() - sorting = ss.run_sorter("mountainsort5", output_folder=cache_folder / "mountainsort5", **run_kwargs) + sorting = ss.run_sorter("mountainsort5", folder=cache_folder / "mountainsort5", **run_kwargs) print("resulting sorting") print(sorting) -def test_tridesclous(run_kwargs): +def test_tridesclous(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder clean_singularity_cache() - sorting = ss.run_sorter("tridesclous", output_folder=cache_folder / "tridesclous", **run_kwargs) + sorting = ss.run_sorter("tridesclous", folder=cache_folder / "tridesclous", **run_kwargs) print("resulting sorting") print(sorting) -def test_ironclust(run_kwargs): +def test_ironclust(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder clean_singularity_cache() - sorting = ss.run_sorter("ironclust", output_folder=cache_folder / "ironclust", fGpu=False, **run_kwargs) + sorting = ss.run_sorter("ironclust", 
folder=cache_folder / "ironclust", fGpu=False, **run_kwargs) print("resulting sorting") print(sorting) -def test_waveclus(run_kwargs): +def test_waveclus(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder clean_singularity_cache() - sorting = ss.run_sorter(sorter_name="waveclus", output_folder=cache_folder / "waveclus", **run_kwargs) + sorting = ss.run_sorter(sorter_name="waveclus", folder=cache_folder / "waveclus", **run_kwargs) print("resulting sorting") print(sorting) -def test_hdsort(run_kwargs): +def test_hdsort(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder clean_singularity_cache() - sorting = ss.run_sorter(sorter_name="hdsort", output_folder=cache_folder / "hdsort", **run_kwargs) + sorting = ss.run_sorter(sorter_name="hdsort", folder=cache_folder / "hdsort", **run_kwargs) print("resulting sorting") print(sorting) -def test_kilosort1(run_kwargs): +def test_kilosort1(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder clean_singularity_cache() - sorting = ss.run_sorter(sorter_name="kilosort", output_folder=cache_folder / "kilosort", useGPU=False, **run_kwargs) + sorting = ss.run_sorter(sorter_name="kilosort", folder=cache_folder / "kilosort", useGPU=False, **run_kwargs) print("resulting sorting") print(sorting) -def test_combinato(run_kwargs): +def test_combinato(run_kwargs, create_cache_folder): + cache_folder = create_cache_folder clean_singularity_cache() rec = run_kwargs["recording"] channels = rec.get_channel_ids()[0:1] rec_one_channel = rec.channel_slice(channels) run_kwargs["recording"] = rec_one_channel - sorting = ss.run_sorter(sorter_name="combinato", output_folder=cache_folder / "combinato", **run_kwargs) + sorting = ss.run_sorter(sorter_name="combinato", folder=cache_folder / "combinato", **run_kwargs) print(sorting) @pytest.mark.skip("Klusta is not supported anymore for Python>=3.8") -def test_klusta(run_kwargs): +def test_klusta(run_kwargs, create_cache_folder): + cache_folder = 
create_cache_folder clean_singularity_cache() - sorting = ss.run_sorter("klusta", output_folder=cache_folder / "klusta", **run_kwargs) + sorting = ss.run_sorter("klusta", folder=cache_folder / "klusta", **run_kwargs) print(sorting) diff --git a/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py b/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py index c8e35f33b3..eb238abdf4 100644 --- a/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py +++ b/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py @@ -42,38 +42,38 @@ def run_kwargs(): def test_kilosort2(run_kwargs): clean_singularity_cache() - sorting = ss.run_sorter(sorter_name="kilosort2", output_folder="kilosort2", **run_kwargs) + sorting = ss.run_sorter(sorter_name="kilosort2", folder="kilosort2", **run_kwargs) print(sorting) def test_kilosort2_5(run_kwargs): clean_singularity_cache() - sorting = ss.run_sorter(sorter_name="kilosort2_5", output_folder="kilosort2_5", **run_kwargs) + sorting = ss.run_sorter(sorter_name="kilosort2_5", folder="kilosort2_5", **run_kwargs) print(sorting) def test_kilosort3(run_kwargs): clean_singularity_cache() - sorting = ss.run_sorter(sorter_name="kilosort3", output_folder="kilosort3", **run_kwargs) + sorting = ss.run_sorter(sorter_name="kilosort3", folder="kilosort3", **run_kwargs) print(sorting) def test_kilosort4(run_kwargs): clean_singularity_cache() - sorting = ss.run_sorter(sorter_name="kilosort4", output_folder="kilosort4", **run_kwargs) + sorting = ss.run_sorter(sorter_name="kilosort4", folder="kilosort4", **run_kwargs) print(sorting) def test_pykilosort(run_kwargs): clean_singularity_cache() - sorting = ss.run_sorter(sorter_name="pykilosort", output_folder="pykilosort", **run_kwargs) + sorting = ss.run_sorter(sorter_name="pykilosort", folder="pykilosort", **run_kwargs) print(sorting) @pytest.mark.skip("YASS is not supported anymore for Python>=3.8") def test_yass(run_kwargs): 
clean_singularity_cache() - sorting = ss.run_sorter(sorter_name="yass", output_folder="yass", **run_kwargs) + sorting = ss.run_sorter(sorter_name="yass", folder="yass", **run_kwargs) print(sorting) diff --git a/src/spikeinterface/sorters/tests/common_tests.py b/src/spikeinterface/sorters/tests/common_tests.py index 5339918f11..0bcd38c433 100644 --- a/src/spikeinterface/sorters/tests/common_tests.py +++ b/src/spikeinterface/sorters/tests/common_tests.py @@ -1,18 +1,12 @@ from __future__ import annotations import pytest -from pathlib import Path import shutil from spikeinterface import generate_ground_truth_recording from spikeinterface.sorters import run_sorter from spikeinterface.core.snippets_tools import snippets_from_sorting -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "sorters" -else: - cache_folder = Path("cache_folder") / "sorters" - class SorterCommonTestSuite: """ @@ -23,12 +17,16 @@ class SorterCommonTestSuite: SorterClass = None + @pytest.fixture(autouse=True) + def create_cache_folder(self, tmp_path_factory): + self.cache_folder = tmp_path_factory.mktemp("cache_folder") + def setUp(self): recording, sorting_gt = generate_ground_truth_recording(num_channels=4, durations=[60], seed=0) - rec_folder = cache_folder / "rec" + rec_folder = self.cache_folder / "rec" if rec_folder.is_dir(): shutil.rmtree(rec_folder) - self.recording = recording.save(folder=cache_folder / "rec", verbose=False, format="binary") + self.recording = recording.save(folder=self.cache_folder / "rec", verbose=False, format="binary") print(self.recording) def test_with_run(self): @@ -39,7 +37,7 @@ def test_with_run(self): sorter_name = self.SorterClass.sorter_name - output_folder = cache_folder / sorter_name + output_folder = self.cache_folder / sorter_name sorter_params = self.SorterClass.default_params() @@ -77,11 +75,15 @@ class SnippetsSorterCommonTestSuite: * run once """ + @pytest.fixture(autouse=True) + def create_cache_folder(self, 
tmp_path_factory): + self.cache_folder = tmp_path_factory.mktemp("cache_folder") + SorterClass = None def setUp(self): recording, sorting_gt = generate_ground_truth_recording(num_channels=4, durations=[60], seed=0) - snippets_folder = cache_folder / "snippets" + snippets_folder = self.cache_folder / "snippets" if snippets_folder.is_dir(): shutil.rmtree(snippets_folder) @@ -98,7 +100,7 @@ def test_with_run(self): sorter_name = self.SorterClass.sorter_name - output_folder = cache_folder / sorter_name + output_folder = self.cache_folder / sorter_name sorter_params = self.SorterClass.default_params() diff --git a/src/spikeinterface/sorters/tests/test_container_tools.py b/src/spikeinterface/sorters/tests/test_container_tools.py index 16d1e0a4a4..3ae03abff1 100644 --- a/src/spikeinterface/sorters/tests/test_container_tools.py +++ b/src/spikeinterface/sorters/tests/test_container_tools.py @@ -11,13 +11,10 @@ ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "sorters" -else: - cache_folder = Path("cache_folder") / "sorters" - -def setup_module(): +@pytest.fixture(scope="module") +def setup_module(tmp_path_factory): + cache_folder = tmp_path_factory.mktemp("cache_folder") test_dirs = [cache_folder / "mono", cache_folder / "multi"] for test_dir in test_dirs: if test_dir.exists(): @@ -27,9 +24,11 @@ def setup_module(): rec2, _ = generate_ground_truth_recording(durations=[10, 10, 10]) rec2 = rec2.save(folder=cache_folder / "multi") + return cache_folder -def test_find_recording_folders(): +def test_find_recording_folders(setup_module): + cache_folder = setup_module rec1 = si.load_extractor(cache_folder / "mono") rec2 = si.load_extractor(cache_folder / "multi" / "binary.json", base_folder=cache_folder / "multi") diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index 4fca09e2a1..362d45cbff 100644 --- 
a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -1,4 +1,3 @@ -import os import sys import shutil import time @@ -6,49 +5,37 @@ import pytest from pathlib import Path -from spikeinterface.core import load_extractor - from spikeinterface import generate_ground_truth_recording from spikeinterface.sorters import run_sorter_jobs, run_sorter_by_property -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "sorters" -else: - cache_folder = Path("cache_folder") / "sorters" - -base_output = cache_folder / "sorter_output" - # no need to have many -num_recordings = 2 -sorters = ["tridesclous2"] +NUM_RECORDINGS = 2 +SORTERS = ["tridesclous2"] -def setup_module(): - base_seed = 42 - for i in range(num_recordings): - rec, _ = generate_ground_truth_recording(num_channels=8, durations=[10.0], seed=base_seed + i) - rec_folder = cache_folder / f"toy_rec_{i}" - if rec_folder.is_dir(): - shutil.rmtree(rec_folder) +def create_recordings(NUM_RECORDINGS=2, base_seed=42): + recordings = [] + for i in range(NUM_RECORDINGS): + recording, _ = generate_ground_truth_recording(num_channels=8, durations=[10.0], seed=base_seed + i) if i % 2 == 0: - rec.set_channel_groups(["0"] * 4 + ["1"] * 4) + recording.set_channel_groups(["0"] * 4 + ["1"] * 4) else: - rec.set_channel_groups([0] * 4 + [1] * 4) + recording.set_channel_groups([0] * 4 + [1] * 4) + recordings.append(recording) + return recordings - rec.save(folder=rec_folder) - -def get_job_list(): +def get_job_list(base_folder): jobs = [] - for i in range(num_recordings): - for sorter_name in sorters: - recording = load_extractor(cache_folder / f"toy_rec_{i}") + recordings = create_recordings(NUM_RECORDINGS) + for i, recording in enumerate(recordings): + for sorter_name in SORTERS: kwargs = dict( sorter_name=sorter_name, recording=recording, - folder=base_output / f"{sorter_name}_rec{i}", + folder=base_folder / f"{sorter_name}_rec{i}", verbose=True, 
raise_error=False, ) @@ -57,31 +44,30 @@ def get_job_list(): return jobs -@pytest.fixture(scope="module") -def job_list(): - return get_job_list() +@pytest.fixture(scope="function") +def job_list(create_cache_folder): + cache_folder = create_cache_folder + folder = cache_folder / "sorting_output" + return get_job_list(folder) def test_run_sorter_jobs_loop(job_list): - if base_output.is_dir(): - shutil.rmtree(base_output) sortings = run_sorter_jobs(job_list, engine="loop", return_output=True) print(sortings) @pytest.mark.skipif(True, reason="tridesclous is already multiprocessing, joblib cannot run it in parralel") def test_run_sorter_jobs_joblib(job_list): - if base_output.is_dir(): - shutil.rmtree(base_output) sortings = run_sorter_jobs( job_list, engine="joblib", engine_kwargs=dict(n_jobs=2, backend="loky"), return_output=True ) print(sortings) -def test_run_sorter_jobs_processpoolexecutor(job_list): - if base_output.is_dir(): - shutil.rmtree(base_output) +def test_run_sorter_jobs_processpoolexecutor(job_list, create_cache_folder): + cache_folder = create_cache_folder + if (cache_folder / "sorting_output").is_dir(): + shutil.rmtree(cache_folder / "sorting_output") sortings = run_sorter_jobs( job_list, engine="processpoolexecutor", engine_kwargs=dict(max_workers=2), return_output=True ) @@ -90,8 +76,6 @@ def test_run_sorter_jobs_processpoolexecutor(job_list): @pytest.mark.skipif(True, reason="This is tested locally") def test_run_sorter_jobs_dask(job_list): - if base_output.is_dir(): - shutil.rmtree(base_output) # create a dask Client for a slurm queue from dask.distributed import Client @@ -122,11 +106,10 @@ def test_run_sorter_jobs_dask(job_list): @pytest.mark.skip("Slurm launcher need a machine with slurm") -def test_run_sorter_jobs_slurm(job_list): - if base_output.is_dir(): - shutil.rmtree(base_output) +def test_run_sorter_jobs_slurm(job_list, create_cache_folder): + cache_folder = create_cache_folder - working_folder = cache_folder / "test_run_sorters_slurm" 
+ working_folder = cache_folder / "test_run_sorters_slurm" if working_folder.is_dir(): shutil.rmtree(working_folder) @@ -143,7 +126,8 @@ def test_run_sorter_jobs_slurm(job_list): ) -def test_run_sorter_by_property(): +def test_run_sorter_by_property(create_cache_folder): + cache_folder = create_cache_folder working_folder1 = cache_folder / "test_run_sorter_by_property_1" if working_folder1.is_dir(): shutil.rmtree(working_folder1) @@ -152,7 +136,9 @@ if working_folder2.is_dir(): shutil.rmtree(working_folder2) - rec0 = load_extractor(cache_folder / "toy_rec_0") + recordings = create_recordings(NUM_RECORDINGS) + + rec0 = recordings[0] rec0_by = rec0.split_by("group") group_names0 = list(rec0_by.keys()) @@ -161,7 +147,7 @@ assert "group" in sorting0.get_property_keys() assert all([g in group_names0 for g in sorting0.get_property("group")]) - rec1 = load_extractor(cache_folder / "toy_rec_1") + rec1 = recordings[1] rec1_by = rec1.split_by("group") group_names1 = list(rec1_by.keys()) @@ -172,8 +158,9 @@ if __name__ == "__main__": - setup_module() - job_list = get_job_list() + # setup_module() + tmp_folder = Path("tmp") + job_list = get_job_list(tmp_folder) # test_run_sorter_jobs_loop(job_list) # test_run_sorter_jobs_joblib(job_list) @@ -182,4 +169,4 @@ # test_run_sorter_jobs_dask(job_list) # test_run_sorter_jobs_slurm(job_list) - test_run_sorter_by_property() + test_run_sorter_by_property(tmp_folder) diff --git a/src/spikeinterface/sorters/tests/test_runsorter.py b/src/spikeinterface/sorters/tests/test_runsorter.py index df7389e844..470bdc3602 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter.py +++ b/src/spikeinterface/sorters/tests/test_runsorter.py @@ -3,33 +3,25 @@ from pathlib import Path import shutil -import spikeinterface as si -from spikeinterface import download_dataset, generate_ground_truth_recording, 
load_extractor -from spikeinterface.extractors import read_mearec +from spikeinterface import generate_ground_truth_recording from spikeinterface.sorters import run_sorter ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "sorters" -else: - cache_folder = Path("cache_folder") / "sorters" +def _generate_recording(): + recording, _ = generate_ground_truth_recording(num_channels=8, durations=[10.0], seed=2205) + return recording -rec_folder = cache_folder / "recording" +@pytest.fixture(scope="module") +def generate_recording(): + return _generate_recording() -def setup_module(): - if rec_folder.exists(): - shutil.rmtree(rec_folder) - recording, sorting_gt = generate_ground_truth_recording(num_channels=8, durations=[10.0], seed=2205) - recording = recording.save(folder=rec_folder) - -def test_run_sorter_local(): - # local_path = download_dataset(remote_path="mearec/mearec_test_10s.h5") - # recording, sorting_true = read_mearec(local_path) - recording = load_extractor(rec_folder) +def test_run_sorter_local(generate_recording, create_cache_folder): + recording = generate_recording + cache_folder = create_cache_folder sorter_params = {"detect_threshold": 4.9} @@ -48,11 +40,9 @@ def test_run_sorter_local(): @pytest.mark.skipif(ON_GITHUB, reason="Docker tests don't run on github: test locally") -def test_run_sorter_docker(): - # mearec_filename = download_dataset(remote_path="mearec/mearec_test_10s.h5", unlock=True) - # recording, sorting_true = read_mearec(mearec_filename) - - recording = load_extractor(rec_folder) +def test_run_sorter_docker(generate_recording, create_cache_folder): + recording = generate_recording + cache_folder = create_cache_folder sorter_params = {"detect_threshold": 4.9} @@ -82,17 +72,13 @@ def test_run_sorter_docker(): @pytest.mark.skipif(ON_GITHUB, reason="Singularity tests don't run on github: test it locally") -def test_run_sorter_singularity(): - # 
mearec_filename = download_dataset(remote_path="mearec/mearec_test_10s.h5", unlock=True) - # recording, sorting_true = read_mearec(mearec_filename) +def test_run_sorter_singularity(generate_recording, create_cache_folder): + recording = generate_recording + cache_folder = create_cache_folder # use an output folder outside of the package. otherwise dev mode will not work - singularity_cache_folder = Path(si.__file__).parents[3] / "sandbox" - singularity_cache_folder.mkdir(exist_ok=True) - - recording = load_extractor(rec_folder) - - sorter_params = {"detect_threshold": 4.9} + # singularity_cache_folder = Path(si.__file__).parents[3] / "sandbox" + # singularity_cache_folder.mkdir(exist_ok=True) sorter_params = {"detect_threshold": 4.9} @@ -100,7 +86,7 @@ for installation_mode in ("dev", "pypi", "github"): print(f"\nTest with installation_mode {installation_mode}") - output_folder = singularity_cache_folder / f"sorting_tdc_singularity_{installation_mode}" + output_folder = cache_folder / f"sorting_tdc_singularity_{installation_mode}" sorting = run_sorter( "tridesclous", recording, @@ -121,7 +107,8 @@ if __name__ == "__main__": - setup_module() - # test_run_sorter_local() - # test_run_sorter_docker() - test_run_sorter_singularity() + rec = _generate_recording() + cache_folder = Path("tmp") + # test_run_sorter_local(rec, cache_folder) + # test_run_sorter_docker(rec, cache_folder) + test_run_sorter_singularity(rec, cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py index 313f19537e..1e9f8abae9 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py @@ -21,12 +21,6 @@ ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) -if hasattr(pytest, 
"global_test_folder"): - cache_folder = pytest.global_test_folder / "sortingcomponents_benchmark" -else: - cache_folder = Path("cache_folder") / "sortingcomponents_benchmark" - - def make_dataset(): recording, gt_sorting = generate_ground_truth_recording( durations=[60.0], diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py index bb9d3b4ed1..bc36fb607c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py @@ -3,15 +3,15 @@ import shutil -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.sortingcomponents.benchmark.benchmark_clustering import ClusteringStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel @pytest.mark.skip() -def test_benchmark_clustering(): - +def test_benchmark_clustering(create_cache_folder): + cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") recording, gt_sorting, gt_analyzer = make_dataset() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index 4837160dc0..aa9b16bb97 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -10,15 +10,14 @@ from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( make_dataset, - cache_folder, compute_gt_templates, ) from 
spikeinterface.sortingcomponents.benchmark.benchmark_matching import MatchingStudy @pytest.mark.skip() -def test_benchmark_matching(): - +def test_benchmark_matching(create_cache_folder): + cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") recording, gt_sorting, gt_analyzer = make_dataset() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py index dec0e612f8..696531b221 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -5,15 +5,14 @@ from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( make_drifting_dataset, - cache_folder, ) from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import MotionEstimationStudy @pytest.mark.skip() -def test_benchmark_motion_estimaton(): - +def test_benchmark_motion_estimaton(create_cache_folder): + cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") data = make_drifting_dataset() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index bf4522df94..4b7264a9de 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -8,7 +8,6 @@ from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( make_drifting_dataset, - cache_folder, ) from spikeinterface.sortingcomponents.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy @@ -19,8 +18,8 @@ @pytest.mark.skip() -def 
test_benchmark_motion_interpolation(): - +def test_benchmark_motion_interpolation(create_cache_folder): + cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") data = make_drifting_dataset() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py index e37e8eca14..dffe1529b7 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py @@ -3,14 +3,15 @@ import shutil -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.sortingcomponents.benchmark.benchmark_peak_detection import PeakDetectionStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel @pytest.mark.skip() -def test_benchmark_peak_detection(): +def test_benchmark_peak_detection(create_cache_folder): + cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") # recording, gt_sorting = make_dataset() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py index 8627034cef..b6f89dcd36 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py @@ -3,14 +3,15 @@ import shutil -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder +from 
spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.sortingcomponents.benchmark.benchmark_peak_localization import PeakLocalizationStudy from spikeinterface.sortingcomponents.benchmark.benchmark_peak_localization import UnitLocalizationStudy @pytest.mark.skip() -def test_benchmark_peak_localization(): +def test_benchmark_peak_localization(create_cache_folder): + cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") # recording, gt_sorting = make_dataset() @@ -55,7 +56,8 @@ def test_benchmark_peak_localization(): @pytest.mark.skip() -def test_benchmark_unit_localization(): +def test_benchmark_unit_localization(create_cache_folder): + cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") recording, gt_sorting = make_dataset() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py index 1e65dfe6cc..a9e404292d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py @@ -2,7 +2,8 @@ @pytest.mark.skip() -def test_benchmark_peak_selection(): +def test_benchmark_peak_selection(create_cache_folder): + cache_folder = create_cache_folder pass diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index 36f623ebf8..597eee7a99 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -1,5 +1,5 @@ import pytest -from pathlib import Path + import shutil import numpy as np @@ -7,7 +7,6 @@ from spikeinterface.sortingcomponents.peak_detection import detect_peaks from 
spikeinterface.sortingcomponents.motion_estimation import estimate_motion - from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording from spikeinterface.core.node_pipeline import ExtractDenseWaveforms @@ -15,10 +14,6 @@ from spikeinterface.sortingcomponents.tests.common import make_dataset -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "sortingcomponents" -else: - cache_folder = Path("cache_folder") / "sortingcomponents" DEBUG = False @@ -29,9 +24,10 @@ plt.show() -def setup_module(): +@pytest.fixture(scope="module") +def setup_module(tmp_path_factory): recording, sorting = make_dataset() - + cache_folder = tmp_path_factory.mktemp("cache_folder") cache_folder.mkdir(parents=True, exist_ok=True) # detect and localize @@ -50,13 +46,18 @@ def setup_module(): progress_bar=True, pipeline_nodes=pipeline_nodes, ) - np.save(cache_folder / "dataset_peaks.npy", peaks) - np.save(cache_folder / "dataset_peak_locations.npy", peak_locations) + peaks_path = cache_folder / "dataset_peaks.npy" + np.save(peaks_path, peaks) + peak_location_path = cache_folder / "dataset_peak_locations.npy" + np.save(peak_location_path, peak_locations) + + return recording, sorting, cache_folder -def test_estimate_motion(): - recording, sorting = make_dataset() +def test_estimate_motion(setup_module): + # recording, sorting = make_dataset() + recording, sorting, cache_folder = setup_module peaks = np.load(cache_folder / "dataset_peaks.npy") peak_locations = np.load(cache_folder / "dataset_peak_locations.npy") diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index cc3434b782..de22ee010d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -1,9 +1,5 @@ -import pytest -from pathlib import Path import numpy as 
np -from spikeinterface import download_dataset - from spikeinterface.sortingcomponents.motion_interpolation import ( correct_motion_on_peaks, interpolate_motion_on_traces, @@ -13,12 +9,6 @@ from spikeinterface.sortingcomponents.tests.common import make_dataset -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "sortingcomponents" -else: - cache_folder = Path("cache_folder") / "sortingcomponents" - - def make_fake_motion(rec): # make a fake motion vector duration = rec.get_total_duration() diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 6703c8b057..fdc937dc25 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -1,7 +1,6 @@ import unittest import pytest import os -from pathlib import Path if __name__ != "__main__": try: @@ -24,12 +23,6 @@ from spikeinterface.preprocessing import scale -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "widgets" -else: - cache_folder = Path("cache_folder") / "widgets" - - ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) SKIP_SORTINGVIEW = bool(os.getenv("SKIP_SORTINGVIEW"))