Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Nov 22, 2024
1 parent f0ec139 commit cc8b4c4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/spikeinterface/qualitymetrics/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")


@pytest.fixture(scope="module")
def small_sorting_analyzer():

def make_small_analyzer():
recording, sorting = generate_ground_truth_recording(
durations=[2.0],
num_units=10,
Expand All @@ -34,6 +34,9 @@ def small_sorting_analyzer():

return sorting_analyzer

@pytest.fixture(scope="module")
def small_sorting_analyzer():
return make_small_analyzer()

@pytest.fixture(scope="module")
def sorting_analyzer_simple():
Expand Down
14 changes: 13 additions & 1 deletion src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ def test_calculate_pc_metrics(small_sorting_analyzer):
assert not np.all(np.isnan(res1[metric_name].values))
assert not np.all(np.isnan(res2[metric_name].values))

assert np.array_equal(res1[metric_name].values, res2[metric_name].values)
# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# ax.plot(res1[metric_name].values)
# ax.plot(res2[metric_name].values)
# ax.plot(res2[metric_name].values - res1[metric_name].values)
# plt.show()

np.testing.assert_almost_equal(res1[metric_name].values, res2[metric_name].values, decimal=4)


def test_pca_metrics_multi_processing(small_sorting_analyzer):
Expand All @@ -41,3 +48,8 @@ def test_pca_metrics_multi_processing(small_sorting_analyzer):
res2 = compute_pc_metrics(
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True
)

if __name__ == "__main__":
from spikeinterface.qualitymetrics.tests.conftest import make_small_analyzer
small_sorting_analyzer = make_small_analyzer()
test_calculate_pc_metrics(small_sorting_analyzer)

0 comments on commit cc8b4c4

Please sign in to comment.