Skip to content

Commit

Permalink
Merge pull request #3249 from chrishalcrow/simplify-qm-tests
Browse files Browse the repository at this point in the history
Refactor quality metrics tests to use fixture
  • Loading branch information
alejoe91 authored Sep 12, 2024
2 parents 73f4d58 + 41f73ed commit 3d725a4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 77 deletions.
40 changes: 37 additions & 3 deletions src/spikeinterface/qualitymetrics/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
create_sorting_analyzer,
)

job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")

def _small_sorting_analyzer():

@pytest.fixture(scope="module")
def small_sorting_analyzer():
recording, sorting = generate_ground_truth_recording(
durations=[2.0],
num_units=10,
Expand All @@ -33,5 +36,36 @@ def _small_sorting_analyzer():


@pytest.fixture(scope="module")
def small_sorting_analyzer():
return _small_sorting_analyzer()
def sorting_analyzer_simple():
# we need high firing rate for amplitude_cutoff
recording, sorting = generate_ground_truth_recording(
durations=[
120.0,
],
sampling_frequency=30_000.0,
num_channels=6,
num_units=10,
generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
generate_unit_locations_kwargs=dict(
margin_um=5.0,
minimum_z=5.0,
maximum_z=20.0,
),
generate_templates_kwargs=dict(
unit_params=dict(
alpha=(200.0, 500.0),
)
),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
seed=1205,
)

sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205)
sorting_analyzer.compute("noise_levels")
sorting_analyzer.compute("waveforms", **job_kwargs)
sorting_analyzer.compute("templates")
sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

return sorting_analyzer
Original file line number Diff line number Diff line change
Expand Up @@ -135,36 +135,6 @@ def test_unit_id_order_independence(small_sorting_analyzer):
assert quality_metrics_2[metric][1] == metric_1_data["#4"]


def _sorting_analyzer_simple():
recording, sorting = generate_ground_truth_recording(
durations=[
50.0,
],
sampling_frequency=30_000.0,
num_channels=6,
num_units=10,
generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
seed=2205,
)

sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205)
sorting_analyzer.compute("noise_levels")
sorting_analyzer.compute("waveforms", **job_kwargs)
sorting_analyzer.compute("templates")
sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs)
sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

return sorting_analyzer


@pytest.fixture(scope="module")
def sorting_analyzer_simple():
return _sorting_analyzer_simple()


def _sorting_violation():
max_time = 100.0
sampling_frequency = 30000
Expand Down Expand Up @@ -576,6 +546,7 @@ def test_calculate_sd_ratio(sorting_analyzer_simple):
test_unit_structure_in_output(_small_sorting_analyzer())

# test_calculate_firing_rate_num_spikes(sorting_analyzer)

# test_calculate_snrs(sorting_analyzer)
# test_calculate_amplitude_cutoff(sorting_analyzer)
# test_calculate_presence_ratio(sorting_analyzer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pathlib import Path
import numpy as np


from spikeinterface.core import (
generate_ground_truth_recording,
create_sorting_analyzer,
Expand All @@ -15,54 +14,11 @@
compute_quality_metrics,
)


job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")


def get_sorting_analyzer(seed=2205):
# we need high firing rate for amplitude_cutoff
recording, sorting = generate_ground_truth_recording(
durations=[
120.0,
],
sampling_frequency=30_000.0,
num_channels=6,
num_units=10,
generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
generate_unit_locations_kwargs=dict(
margin_um=5.0,
minimum_z=5.0,
maximum_z=20.0,
),
generate_templates_kwargs=dict(
unit_params=dict(
alpha=(200.0, 500.0),
)
),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
seed=seed,
)

sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=seed)
sorting_analyzer.compute("noise_levels")
sorting_analyzer.compute("waveforms", **job_kwargs)
sorting_analyzer.compute("templates")
sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

return sorting_analyzer


@pytest.fixture(scope="module")
def sorting_analyzer_simple():
sorting_analyzer = get_sorting_analyzer(seed=2205)
return sorting_analyzer


def test_compute_quality_metrics(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
print(sorting_analyzer)

# without PCs
metrics = compute_quality_metrics(
Expand Down

0 comments on commit 3d725a4

Please sign in to comment.