Skip to content

Commit

Permalink
Merge pull request #2121 from alejoe91/fix-failing-qm-tests
Browse files Browse the repository at this point in the history
Use default cutouts for peak-sign test
  • Loading branch information
alejoe91 authored Oct 23, 2023
2 parents 1c6535a + 0159221 commit 399fa54
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 28 deletions.
15 changes: 6 additions & 9 deletions src/spikeinterface/core/tests/test_template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import shutil
from pathlib import Path

from spikeinterface import WaveformExtractor, load_extractor, extract_waveforms, generate_recording, generate_sorting
from spikeinterface import load_extractor, extract_waveforms, load_waveforms, generate_recording, generate_sorting

from spikeinterface.core import (
get_template_amplitudes,
Expand Down Expand Up @@ -33,25 +33,23 @@ def setup_module():
sorting.set_property("group", [0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
sorting = sorting.save(folder=cache_folder / "toy_sort")

we = WaveformExtractor.create(recording, sorting, cache_folder / "toy_waveforms")
we.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500)
we.run_extract_waveforms(n_jobs=1, chunk_size=30000)
we = extract_waveforms(recording, sorting, cache_folder / "toy_waveforms")


def test_get_template_amplitudes():
we = WaveformExtractor.load(cache_folder / "toy_waveforms")
we = load_waveforms(cache_folder / "toy_waveforms")
peak_values = get_template_amplitudes(we)
print(peak_values)


def test_get_template_extremum_channel():
we = WaveformExtractor.load(cache_folder / "toy_waveforms")
we = load_waveforms(cache_folder / "toy_waveforms")
extremum_channels_ids = get_template_extremum_channel(we, peak_sign="both")
print(extremum_channels_ids)


def test_get_template_extremum_channel_peak_shift():
we = WaveformExtractor.load(cache_folder / "toy_waveforms")
we = load_waveforms(cache_folder / "toy_waveforms")
shifts = get_template_extremum_channel_peak_shift(we, peak_sign="neg")
print(shifts)

Expand All @@ -72,7 +70,7 @@ def test_get_template_extremum_channel_peak_shift():


def test_get_template_extremum_amplitude():
we = WaveformExtractor.load(cache_folder / "toy_waveforms")
we = load_waveforms(cache_folder / "toy_waveforms")

extremum_channels_ids = get_template_extremum_amplitude(we, peak_sign="both")
print(extremum_channels_ids)
Expand All @@ -85,4 +83,3 @@ def test_get_template_extremum_amplitude():
test_get_template_extremum_channel()
test_get_template_extremum_channel_peak_shift()
test_get_template_extremum_amplitude()
test_get_template_channel_sparsity()
23 changes: 13 additions & 10 deletions src/spikeinterface/core/tests/test_waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,14 +388,17 @@ def test_unfiltered_extraction():
shutil.rmtree(wf_folder)
we = WaveformExtractor.create(recording, sorting, wf_folder, mode=mode, allow_unfiltered=True)

we.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500)

ms_before = 2.0
ms_after = 3.0
max_spikes_per_unit = 500
num_samples = int((ms_before + ms_after) * sampling_frequency / 1000.0)
we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit)
we.run_extract_waveforms(n_jobs=1, chunk_size=30000)
we.run_extract_waveforms(n_jobs=4, chunk_size=30000, progress_bar=True)

wfs = we.get_waveforms(0)
assert wfs.shape[0] <= 500
assert wfs.shape[1:] == (210, num_channels)
assert wfs.shape[0] <= max_spikes_per_unit
assert wfs.shape[1:] == (num_samples, num_channels)

wfs, sampled_index = we.get_waveforms(0, with_index=True)

Expand All @@ -406,18 +409,18 @@ def test_unfiltered_extraction():
wfs = we.get_waveforms(0)

template = we.get_template(0)
assert template.shape == (210, 2)
assert template.shape == (num_samples, 2)
templates = we.get_all_templates()
assert templates.shape == (num_units, 210, num_channels)
assert templates.shape == (num_units, num_samples, num_channels)

wf_std = we.get_template(0, mode="std")
assert wf_std.shape == (210, num_channels)
assert wf_std.shape == (num_samples, num_channels)
wfs_std = we.get_all_templates(mode="std")
assert wfs_std.shape == (num_units, 210, num_channels)
assert wfs_std.shape == (num_units, num_samples, num_channels)

wf_segment = we.get_template_segment(unit_id=0, segment_index=0)
assert wf_segment.shape == (210, num_channels)
assert wf_segment.shape == (210, num_channels)
assert wf_segment.shape == (num_samples, num_channels)
assert wf_segment.shape == (num_samples, num_channels)


def test_portability():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ def setUp(self):
recording,
sorting,
cache_folder / "toy_waveforms_1seg",
ms_before=3.0,
ms_after=4.0,
max_spikes_per_unit=500,
sparse=False,
n_jobs=1,
Expand Down Expand Up @@ -90,8 +88,6 @@ def setUp(self):
recording,
sorting,
cache_folder / "toy_waveforms_2seg",
ms_before=3.0,
ms_after=4.0,
max_spikes_per_unit=500,
sparse=False,
n_jobs=1,
Expand All @@ -115,8 +111,6 @@ def setUp(self):
sorting,
mode="memory",
sparse=False,
ms_before=3.0,
ms_after=4.0,
max_spikes_per_unit=500,
n_jobs=1,
chunk_size=30000,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,7 @@ def test_peak_sign(self):
# invert recording
rec_inv = scale(rec, gain=-1.0)

we_inv = WaveformExtractor.create(rec_inv, sort, self.cache_folder / "toy_waveforms_inv")
we_inv.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=None)
we_inv.run_extract_waveforms(n_jobs=1, chunk_size=30000)
we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv")

# compute amplitudes
_ = compute_spike_amplitudes(we, peak_sign="neg")
Expand Down

0 comments on commit 399fa54

Please sign in to comment.