Skip to content

Commit

Permalink
Remove cache folder (#2927)
Browse files Browse the repository at this point in the history
Remove pytest.global_test_folder in favor of create_cache_folder module fixture

Authored by: Paul Riganese
Co-authored-by: Alessio Buccino, Joe Ziminski
  • Loading branch information
paulrignanese authored Jun 6, 2024
1 parent 1711188 commit a9fe858
Show file tree
Hide file tree
Showing 76 changed files with 371 additions and 730 deletions.
23 changes: 4 additions & 19 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,10 @@
"widgets", "exporters", "sortingcomponents", "generation"]


# define global test folder
def pytest_sessionstart(session):
# setup_stuff
pytest.global_test_folder = Path(__file__).parent / "test_folder"
if pytest.global_test_folder.is_dir():
shutil.rmtree(pytest.global_test_folder)
pytest.global_test_folder.mkdir()

for mark_name in mark_names:
(pytest.global_test_folder / mark_name).mkdir()
@pytest.fixture(scope="module")
def create_cache_folder(tmp_path_factory):
cache_folder = tmp_path_factory.mktemp("cache_folder")
return cache_folder

def pytest_collection_modifyitems(config, items):
"""
Expand All @@ -45,12 +39,3 @@ def pytest_collection_modifyitems(config, items):
item.add_marker("sorters")
else:
item.add_marker(module)



def pytest_sessionfinish(session, exitstatus):
# teardown_stuff only if tests passed
# We don't delete the test folder in the CI because it was causing problems with the code coverage.
if exitstatus == 0:
if pytest.global_test_folder.is_dir() and not ON_GITHUB:
shutil.rmtree(pytest.global_test_folder)
17 changes: 6 additions & 11 deletions src/spikeinterface/comparison/tests/test_groundtruthstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,13 @@
from spikeinterface.comparison import GroundTruthStudy


if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "comparison"
else:
cache_folder = Path("cache_folder") / "comparison"
cache_folder.mkdir(exist_ok=True, parents=True)

study_folder = cache_folder / "test_groundtruthstudy/"


def setup_module():
@pytest.fixture(scope="module")
def setup_module(tmp_path_factory):
study_folder = tmp_path_factory.mktemp("study_folder")
if study_folder.is_dir():
shutil.rmtree(study_folder)
create_a_study(study_folder)
return study_folder


def simple_preprocess(rec):
Expand Down Expand Up @@ -74,7 +68,8 @@ def create_a_study(study_folder):
# print(study)


def test_GroundTruthStudy():
def test_GroundTruthStudy(setup_module):
study_folder = setup_module
study = GroundTruthStudy(study_folder)
print(study)

Expand Down
17 changes: 8 additions & 9 deletions src/spikeinterface/comparison/tests/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,9 @@
from spikeinterface.preprocessing import bandpass_filter


if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "comparison" / "hybrid"
else:
cache_folder = Path("cache_folder") / "comparison" / "hybrid"


def setup_module():
@pytest.fixture(scope="module")
def setup_module(tmp_path_factory):
cache_folder = tmp_path_factory.mktemp("cache_folder")
if cache_folder.is_dir():
shutil.rmtree(cache_folder)
cache_folder.mkdir(parents=True, exist_ok=True)
Expand All @@ -31,9 +27,11 @@ def setup_module():
wvf_extractor = extract_waveforms(
recording, sorting, folder=cache_folder / "wvf_extractor", ms_before=10.0, ms_after=10.0
)
return cache_folder


def test_hybrid_units_recording():
def test_hybrid_units_recording(setup_module):
cache_folder = setup_module
wvf_extractor = load_waveforms(cache_folder / "wvf_extractor")
print(wvf_extractor)
print(wvf_extractor.sorting_analyzer)
Expand Down Expand Up @@ -63,7 +61,8 @@ def test_hybrid_units_recording():
check_recordings_equal(hybrid_units_recording, saved_2job, return_scaled=False)


def test_hybrid_spikes_recording():
def test_hybrid_spikes_recording(setup_module):
cache_folder = setup_module
wvf_extractor = load_waveforms(cache_folder / "wvf_extractor")
recording = wvf_extractor.recording
sorting = wvf_extractor.sorting
Expand Down
19 changes: 7 additions & 12 deletions src/spikeinterface/comparison/tests/test_multisortingcomparison.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import shutil
import pytest
from pathlib import Path

import pytest
import numpy as np
Expand All @@ -9,18 +7,14 @@
from spikeinterface.extractors import NumpySorting
from spikeinterface.comparison import compare_multiple_sorters, MultiSortingComparison

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "comparison"
else:
cache_folder = Path("cache_folder") / "comparison"


multicomparison_folder = cache_folder / "saved_multisorting_comparison"


def setup_module():
@pytest.fixture(scope="module")
def setup_module(tmp_path_factory):
cache_folder = tmp_path_factory.mktemp("cache_folder")
multicomparison_folder = cache_folder / "saved_multisorting_comparison"
if multicomparison_folder.is_dir():
shutil.rmtree(multicomparison_folder)
return multicomparison_folder


def make_sorting(times1, labels1, times2, labels2, times3, labels3):
Expand All @@ -34,7 +28,8 @@ def make_sorting(times1, labels1, times2, labels2, times3, labels3):
return sorting1, sorting2, sorting3


def test_compare_multiple_sorters():
def test_compare_multiple_sorters(setup_module):
multicomparison_folder = setup_module
# simple match
sorting1, sorting2, sorting3 = make_sorting(
[100, 200, 300, 400, 500, 600, 700, 800, 900],
Expand Down
10 changes: 0 additions & 10 deletions src/spikeinterface/comparison/tests/test_templatecomparison.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
import shutil
import pytest
from pathlib import Path
import numpy as np

from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording
from spikeinterface.comparison import compare_templates, compare_multiple_templates


# if hasattr(pytest, "global_test_folder"):
# cache_folder = pytest.global_test_folder / "comparison"
# else:
# cache_folder = Path("cache_folder") / "comparison"


# test_dir = cache_folder / "temp_comp_test"


# def setup_module():
# if test_dir.is_dir():
# shutil.rmtree(test_dir)
Expand Down
48 changes: 24 additions & 24 deletions src/spikeinterface/core/tests/test_analyzer_extension_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
from pathlib import Path

import shutil

Expand All @@ -11,13 +10,8 @@

import numpy as np

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"


def get_sorting_analyzer(format="memory", sparse=True):
def get_sorting_analyzer(cache_folder, format="memory", sparse=True):
recording, sorting = generate_ground_truth_recording(
durations=[30.0],
sampling_frequency=16000.0,
Expand Down Expand Up @@ -53,7 +47,7 @@ def get_sorting_analyzer(format="memory", sparse=True):
return sorting_analyzer


def _check_result_extension(sorting_analyzer, extension_name):
def _check_result_extension(sorting_analyzer, extension_name, cache_folder):
# select unit_ids to several format
for format in ("memory", "binary_folder", "zarr"):
# for format in ("memory", ):
Expand Down Expand Up @@ -83,39 +77,42 @@ def _check_result_extension(sorting_analyzer, extension_name):
False,
],
)
def test_ComputeRandomSpikes(format, sparse):
sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse)
def test_ComputeRandomSpikes(format, sparse, create_cache_folder):
cache_folder = create_cache_folder
sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse)

ext = sorting_analyzer.compute("random_spikes", max_spikes_per_unit=10, seed=2205)
indices = ext.data["random_spikes_indices"]
assert indices.size == 10 * sorting_analyzer.sorting.unit_ids.size

_check_result_extension(sorting_analyzer, "random_spikes")
_check_result_extension(sorting_analyzer, "random_spikes", cache_folder)
sorting_analyzer.delete_extension("random_spikes")

ext = sorting_analyzer.compute("random_spikes", method="all")
indices = ext.data["random_spikes_indices"]
assert indices.size == len(sorting_analyzer.sorting.to_spike_vector())

_check_result_extension(sorting_analyzer, "random_spikes")
_check_result_extension(sorting_analyzer, "random_spikes", cache_folder)


@pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"])
@pytest.mark.parametrize("sparse", [True, False])
def test_ComputeWaveforms(format, sparse):
sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse)
def test_ComputeWaveforms(format, sparse, create_cache_folder):
cache_folder = create_cache_folder
sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse)

job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True)
sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205)
ext = sorting_analyzer.compute("waveforms", **job_kwargs)
wfs = ext.data["waveforms"]
_check_result_extension(sorting_analyzer, "waveforms")
_check_result_extension(sorting_analyzer, "waveforms", cache_folder)


@pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"])
@pytest.mark.parametrize("sparse", [True, False])
def test_ComputeTemplates(format, sparse):
sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse)
def test_ComputeTemplates(format, sparse, create_cache_folder):
cache_folder = create_cache_folder
sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse)

sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205)

Expand Down Expand Up @@ -187,13 +184,14 @@ def test_ComputeTemplates(format, sparse):
# ax.legend()
# plt.show()

_check_result_extension(sorting_analyzer, "templates")
_check_result_extension(sorting_analyzer, "templates", cache_folder)


@pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"])
@pytest.mark.parametrize("sparse", [True, False])
def test_ComputeNoiseLevels(format, sparse):
sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse)
def test_ComputeNoiseLevels(format, sparse, create_cache_folder):
cache_folder = create_cache_folder
sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse)

sorting_analyzer.compute("noise_levels")
print(sorting_analyzer)
Expand All @@ -212,8 +210,9 @@ def test_get_children_dependencies():
assert rs_children.index("waveforms") < rs_children.index("templates")


def test_delete_on_recompute():
sorting_analyzer = get_sorting_analyzer(format="memory", sparse=False)
def test_delete_on_recompute(create_cache_folder):
cache_folder = create_cache_folder
sorting_analyzer = get_sorting_analyzer(cache_folder, format="memory", sparse=False)
sorting_analyzer.compute("random_spikes")
sorting_analyzer.compute("waveforms")
sorting_analyzer.compute("templates")
Expand All @@ -224,8 +223,9 @@ def test_delete_on_recompute():
assert sorting_analyzer.get_extension("waveforms") is None


def test_compute_several():
sorting_analyzer = get_sorting_analyzer(format="memory", sparse=False)
def test_compute_several(create_cache_folder):
cache_folder = create_cache_folder
sorting_analyzer = get_sorting_analyzer(cache_folder, format="memory", sparse=False)

# should raise an error since waveforms depends on random_spikes, which isn't calculated
with pytest.raises(AssertionError):
Expand Down
9 changes: 2 additions & 7 deletions src/spikeinterface/core/tests/test_baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,9 @@

from spikeinterface.core import generate_recording

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"
cache_folder.mkdir(exist_ok=True, parents=True)


def test_BaseRecording():
def test_BaseRecording(create_cache_folder):
cache_folder = create_cache_folder
num_seg = 2
num_chan = 3
num_samples = 30
Expand Down
9 changes: 2 additions & 7 deletions src/spikeinterface/core/tests/test_basesnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,9 @@
from spikeinterface.core.npysnippetsextractor import NpySnippetsExtractor
from spikeinterface.core.base import BaseExtractor

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"
cache_folder.mkdir(exist_ok=True, parents=True)


def test_BaseSnippets():
def test_BaseSnippets(create_cache_folder):
cache_folder = create_cache_folder
duration = [4, 3]
num_channels = 3
nbefore = 20
Expand Down
8 changes: 2 additions & 6 deletions src/spikeinterface/core/tests/test_basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,9 @@
from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal
from spikeinterface.core.generate import generate_sorting

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"


def test_BaseSorting():
def test_BaseSorting(create_cache_folder):
cache_folder = create_cache_folder
num_seg = 2
file_path = cache_folder / "test_BaseSorting.npz"
file_path.parent.mkdir(exist_ok=True)
Expand Down
9 changes: 2 additions & 7 deletions src/spikeinterface/core/tests/test_binaryfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,8 @@
from spikeinterface.core import generate_recording


if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"


def test_BinaryFolderRecording():
def test_BinaryFolderRecording(create_cache_folder):
cache_folder = create_cache_folder
rec = generate_recording(num_channels=10, durations=[2.0, 2.0])
folder = cache_folder / "binary_folder_1"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import pytest
import numpy as np
from pathlib import Path

from spikeinterface.core import BinaryRecordingExtractor
from spikeinterface.core.numpyextractors import NumpyRecording

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"


def test_BinaryRecordingExtractor():
def test_BinaryRecordingExtractor(create_cache_folder):
cache_folder = create_cache_folder
num_seg = 2
num_channels = 3
num_samples = 30
Expand Down
7 changes: 2 additions & 5 deletions src/spikeinterface/core/tests/test_channelslicerecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@
from spikeinterface.core.generate import generate_recording


def test_ChannelSliceRecording():
if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"
def test_ChannelSliceRecording(create_cache_folder):
cache_folder = create_cache_folder

num_seg = 2
num_chan = 3
Expand Down
Loading

0 comments on commit a9fe858

Please sign in to comment.