Merge branch 'main' of github.com:SpikeInterface/spikeinterface into fix_plot_traces
samuelgarcia committed Oct 6, 2023
2 parents c50de4f + d65978e commit cd83ed3
Showing 12 changed files with 77 additions and 54 deletions.
src/spikeinterface/core/sparsity.py (5 additions & 1 deletion)
@@ -102,7 +102,11 @@ def __init__(self, mask, unit_ids, channel_ids):

self.num_channels = self.channel_ids.size
self.num_units = self.unit_ids.size
- self.max_num_active_channels = self.mask.sum(axis=1).max()
+ if self.mask.shape[0]:
+ self.max_num_active_channels = self.mask.sum(axis=1).max()
+ else:
+ # empty sorting without units
+ self.max_num_active_channels = 0

def __repr__(self):
density = np.mean(self.mask)
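The guard covers the zero-unit case: for an empty sorting the mask has shape (0, num_channels), so `mask.sum(axis=1)` is empty and `.max()` would raise. A minimal sketch of that edge case, assuming only NumPy and the `ChannelSparsity` constructor shown in the hunk above:

    import numpy as np
    from spikeinterface.core.sparsity import ChannelSparsity

    # A sorting with no units yields a (0, num_channels) boolean mask.
    mask = np.zeros((0, 4), dtype=bool)
    sparsity = ChannelSparsity(mask, unit_ids=np.array([], dtype=int), channel_ids=np.arange(4))

    # With the new guard this falls back to 0 instead of raising on an empty reduction.
    print(sparsity.max_num_active_channels)  # 0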
src/spikeinterface/core/tests/test_waveform_extractor.py (1 addition & 0 deletions)
@@ -559,3 +559,4 @@ def test_non_json_object():
test_recordingless()
# test_compute_sparsity()
# test_non_json_object()
+ test_empty_sorting()
src/spikeinterface/core/waveform_extractor.py (5 additions & 4 deletions)
@@ -1457,13 +1457,13 @@ def extract_waveforms(
folder=None,
mode="folder",
precompute_template=("average",),
- ms_before=3.0,
- ms_after=4.0,
+ ms_before=1.0,
+ ms_after=2.0,
max_spikes_per_unit=500,
overwrite=False,
return_scaled=True,
dtype=None,
- sparse=False,
+ sparse=True,
sparsity=None,
num_spikes_for_sparsity=100,
allow_unfiltered=False,
@@ -1507,7 +1507,7 @@ def extract_waveforms(
If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV.
dtype: dtype or None
Dtype of the output waveforms. If None, the recording dtype is maintained.
- sparse: bool (default False)
+ sparse: bool, default: True
If True, before extracting all waveforms the `precompute_sparsity()` function is run using
a few spikes to get an estimate of dense templates to create a ChannelSparsity object.
Then, the waveforms will be sparse at extraction time, which saves a lot of memory.
@@ -1726,6 +1726,7 @@ def precompute_sparsity(
max_spikes_per_unit=num_spikes_for_sparsity,
return_scaled=False,
allow_unfiltered=allow_unfiltered,
+ sparse=False,
**job_kwargs,
)
local_sparsity = compute_sparsity(local_we, **sparse_kwargs)
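With these new defaults, a bare `extract_waveforms` call returns a sparse WaveformExtractor with a 1 ms / 2 ms cut-out, and dense extraction has to be requested explicitly (which is what the test updates below do). A minimal sketch of both calls, assuming `recording` and `sorting` are an existing, matching BaseRecording/BaseSorting pair with a probe attached; the folder names are placeholders:

    from spikeinterface.core import extract_waveforms

    # New defaults: sparse=True, ms_before=1.0, ms_after=2.0
    we_sparse = extract_waveforms(recording, sorting, folder="waveforms_sparse")

    # The previous dense behaviour with the longer cut-out must now be asked for explicitly.
    we_dense = extract_waveforms(
        recording,
        sorting,
        folder="waveforms_dense",
        sparse=False,
        ms_before=3.0,
        ms_after=4.0,
    )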
src/spikeinterface/exporters/tests/test_export_to_phy.py (3 additions & 3 deletions)
@@ -78,7 +78,7 @@ def test_export_to_phy_by_property():
recording = recording.save(folder=rec_folder)
sorting = sorting.save(folder=sort_folder)

- waveform_extractor = extract_waveforms(recording, sorting, waveform_folder)
+ waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False)
sparsity_group = compute_sparsity(waveform_extractor, method="by_property", by_property="group")
export_to_phy(
waveform_extractor,
@@ -96,7 +96,7 @@ def test_export_to_phy_by_property():

# Remove one channel
recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7])
- waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm)
+ waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm, sparse=False)
sparsity_group = compute_sparsity(waveform_extractor_rm, method="by_property", by_property="group")

export_to_phy(
@@ -130,7 +130,7 @@ def test_export_to_phy_by_sparsity():
if f.is_dir():
shutil.rmtree(f)

- waveform_extractor = extract_waveforms(recording, sorting, waveform_folder)
+ waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False)
sparsity_radius = compute_sparsity(waveform_extractor, method="radius", radius_um=50.0)
export_to_phy(
waveform_extractor,
src/spikeinterface/exporters/to_phy.py (1 addition & 0 deletions)
@@ -94,6 +94,7 @@ def export_to_phy(

if waveform_extractor.is_sparse():
used_sparsity = waveform_extractor.sparsity
+ assert sparsity is None
elif sparsity is not None:
used_sparsity = sparsity
else:
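The added assertion makes the precedence explicit: a WaveformExtractor that is already sparse carries its own sparsity, so passing a separate `sparsity` argument to `export_to_phy` is now rejected instead of being silently ignored. A hedged sketch of the calling convention, assuming `we_sparse` is a sparse WaveformExtractor and the output folder is a placeholder:

    from spikeinterface.exporters import export_to_phy

    # Fine: the extractor's own sparsity is reused.
    export_to_phy(we_sparse, output_folder="phy_output")

    # This now trips the assertion, because two sparsity definitions would be ambiguous:
    # export_to_phy(we_sparse, output_folder="phy_output", sparsity=some_other_sparsity)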
src/spikeinterface/postprocessing/correlograms.py (2 additions & 2 deletions)
@@ -137,8 +137,8 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_
def compute_correlograms(
waveform_or_sorting_extractor,
load_if_exists=False,
- window_ms: float = 100.0,
- bin_ms: float = 5.0,
+ window_ms: float = 50.0,
+ bin_ms: float = 1.0,
method: str = "auto",
):
"""Compute auto and cross correlograms.
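The tighter defaults (50 ms window, 1 ms bins) change the size of the returned histograms for any caller that relied on the old 100 ms / 5 ms values. A small sketch, assuming the `generate_sorting` toy helper from `spikeinterface.core` (its parameters here are illustrative):

    from spikeinterface.core import generate_sorting
    from spikeinterface.postprocessing import compute_correlograms

    sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.0])

    # New defaults: window_ms=50.0, bin_ms=1.0
    ccgs, bins = compute_correlograms(sorting)
    print(ccgs.shape, bins.size)

    # The previous behaviour can still be requested explicitly.
    ccgs_old, bins_old = compute_correlograms(sorting, window_ms=100.0, bin_ms=5.0)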
@@ -57,6 +57,7 @@ def setUp(self):
ms_before=3.0,
ms_after=4.0,
max_spikes_per_unit=500,
+ sparse=False,
n_jobs=1,
chunk_size=30000,
overwrite=True,
@@ -92,6 +93,7 @@ def setUp(self):
ms_before=3.0,
ms_after=4.0,
max_spikes_per_unit=500,
+ sparse=False,
n_jobs=1,
chunk_size=30000,
overwrite=True,
@@ -112,6 +114,7 @@ def setUp(self):
recording,
sorting,
mode="memory",
+ sparse=False,
ms_before=3.0,
ms_after=4.0,
max_spikes_per_unit=500,
src/spikeinterface/postprocessing/unit_localization.py (1 addition & 1 deletion)
@@ -96,7 +96,7 @@ def get_extension_function():


def compute_unit_locations(
- waveform_extractor, load_if_exists=False, method="center_of_mass", outputs="numpy", **method_kwargs
+ waveform_extractor, load_if_exists=False, method="monopolar_triangulation", outputs="numpy", **method_kwargs
):
"""
Localize units in 2D or 3D with several methods given the template.
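`compute_unit_locations` now defaults to `monopolar_triangulation` rather than `center_of_mass`. A hedged sketch of the two calls, assuming `we` is an existing WaveformExtractor:

    from spikeinterface.postprocessing import compute_unit_locations

    # New default method.
    unit_locations = compute_unit_locations(we)  # method="monopolar_triangulation"

    # The previous default is still available by name.
    unit_locations_com = compute_unit_locations(we, method="center_of_mass")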
src/spikeinterface/sorters/runsorter.py (1 addition & 1 deletion)
@@ -91,7 +91,7 @@ def run_sorter(
sorter_name: str,
recording: BaseRecording,
output_folder: Optional[str] = None,
- remove_existing_folder: bool = True,
+ remove_existing_folder: bool = False,
delete_output_folder: bool = False,
verbose: bool = False,
raise_error: bool = True,
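With `remove_existing_folder=False` by default, `run_sorter` no longer deletes an existing output folder on its own; callers who depended on the old overwrite behaviour must now opt in. A minimal sketch, assuming `recording` is an existing BaseRecording and that the named sorter is installed (sorter name and folder are placeholders):

    from spikeinterface.sorters import run_sorter

    # New default: an existing output_folder is no longer removed automatically.
    sorting = run_sorter("tridesclous", recording, output_folder="sorter_output")

    # Opt back in to the previous overwrite behaviour explicitly.
    sorting = run_sorter("tridesclous", recording, output_folder="sorter_output", remove_existing_folder=True)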
@@ -43,6 +43,8 @@ def plot(self):
self._do_plot()

def _do_plot(self):
+ from matplotlib import pyplot as plt

fig = self.figure

for ax in fig.axes:
@@ -177,6 +179,8 @@ def plot(self):

def _do_plot(self):
import sklearn
+ import matplotlib.pyplot as plt
+ import matplotlib

# compute similarity
# take index of template (respect unit_ids order)
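Both plotting methods now pull in matplotlib only when `_do_plot` actually runs, so importing the widgets module stays cheap when nothing is drawn. A generic sketch of the same deferred-import pattern; the class and method names here are illustrative, not the widget's real API:

    class LazyFigureWidget:
        """Defers the matplotlib import until a figure is actually requested."""

        def _do_plot(self):
            # Imported here rather than at module level, so merely importing
            # this module does not require matplotlib.
            from matplotlib import pyplot as plt

            fig, ax = plt.subplots()
            return fig, ax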
@@ -32,10 +32,10 @@ def setUp(self):

self.num_units = len(self._sorting.get_unit_ids())
#  self._we = extract_waveforms(self._rec, self._sorting, './toy_example', load_if_exists=True)
- if (cache_folder / "mearec_test").is_dir():
- self._we = load_waveforms(cache_folder / "mearec_test")
+ if (cache_folder / "mearec_test_old_api").is_dir():
+ self._we = load_waveforms(cache_folder / "mearec_test_old_api")
else:
- self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test")
+ self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test_old_api", sparse=False)

self._amplitudes = compute_spike_amplitudes(self._we, peak_sign="neg", outputs="by_unit")
self._gt_comp = sc.compare_sorter_to_ground_truth(self._sorting, self._sorting)
src/spikeinterface/widgets/tests/test_widgets.py (48 additions & 39 deletions)
@@ -48,29 +48,30 @@ def setUpClass(cls):
cls.sorting = se.MEArecSortingExtractor(local_path)

cls.num_units = len(cls.sorting.get_unit_ids())
- if (cache_folder / "mearec_test").is_dir():
- cls.we = load_waveforms(cache_folder / "mearec_test")
+ if (cache_folder / "mearec_test_dense").is_dir():
+ cls.we_dense = load_waveforms(cache_folder / "mearec_test_dense")
else:
- cls.we = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test")
+ cls.we_dense = extract_waveforms(
+ cls.recording, cls.sorting, cache_folder / "mearec_test_dense", sparse=False
+ )
+ metric_names = ["snr", "isi_violation", "num_spikes"]
+ _ = compute_spike_amplitudes(cls.we_dense)
+ _ = compute_unit_locations(cls.we_dense)
+ _ = compute_spike_locations(cls.we_dense)
+ _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names)
+ _ = compute_template_metrics(cls.we_dense)
+ _ = compute_correlograms(cls.we_dense)
+ _ = compute_template_similarity(cls.we_dense)

sw.set_default_plotter_backend("matplotlib")

- metric_names = ["snr", "isi_violation", "num_spikes"]
- _ = compute_spike_amplitudes(cls.we)
- _ = compute_unit_locations(cls.we)
- _ = compute_spike_locations(cls.we)
- _ = compute_quality_metrics(cls.we, metric_names=metric_names)
- _ = compute_template_metrics(cls.we)
- _ = compute_correlograms(cls.we)
- _ = compute_template_similarity(cls.we)

# make sparse waveforms
- cls.sparsity_radius = compute_sparsity(cls.we, method="radius", radius_um=50)
- cls.sparsity_best = compute_sparsity(cls.we, method="best_channels", num_channels=5)
+ cls.sparsity_radius = compute_sparsity(cls.we_dense, method="radius", radius_um=50)
+ cls.sparsity_best = compute_sparsity(cls.we_dense, method="best_channels", num_channels=5)
if (cache_folder / "mearec_test_sparse").is_dir():
cls.we_sparse = load_waveforms(cache_folder / "mearec_test_sparse")
else:
- cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius)
+ cls.we_sparse = cls.we_dense.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius)

cls.skip_backends = ["ipywidgets", "ephyviewer"]

@@ -124,17 +125,17 @@ def test_plot_unit_waveforms(self):
possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- sw.plot_unit_waveforms(self.we, backend=backend, **self.backend_kwargs[backend])
+ sw.plot_unit_waveforms(self.we_dense, backend=backend, **self.backend_kwargs[backend])
unit_ids = self.sorting.unit_ids[:6]
sw.plot_unit_waveforms(
- self.we,
+ self.we_dense,
sparsity=self.sparsity_radius,
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
sw.plot_unit_waveforms(
- self.we,
+ self.we_dense,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
backend=backend,
@@ -148,10 +149,10 @@ def test_plot_unit_templates(self):
possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- sw.plot_unit_templates(self.we, backend=backend, **self.backend_kwargs[backend])
+ sw.plot_unit_templates(self.we_dense, backend=backend, **self.backend_kwargs[backend])
unit_ids = self.sorting.unit_ids[:6]
sw.plot_unit_templates(
- self.we,
+ self.we_dense,
sparsity=self.sparsity_radius,
unit_ids=unit_ids,
backend=backend,
@@ -171,7 +172,7 @@ def test_plot_unit_waveforms_density_map(self):
if backend not in self.skip_backends:
unit_ids = self.sorting.unit_ids[:2]
sw.plot_unit_waveforms_density_map(
- self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]
+ self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]
)

def test_plot_unit_waveforms_density_map_sparsity_radius(self):
@@ -180,7 +181,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self):
if backend not in self.skip_backends:
unit_ids = self.sorting.unit_ids[:2]
sw.plot_unit_waveforms_density_map(
- self.we,
+ self.we_dense,
sparsity=self.sparsity_radius,
same_axis=False,
unit_ids=unit_ids,
@@ -234,11 +235,15 @@ def test_amplitudes(self):
possible_backends = list(sw.AmplitudesWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- sw.plot_amplitudes(self.we, backend=backend, **self.backend_kwargs[backend])
- unit_ids = self.we.unit_ids[:4]
- sw.plot_amplitudes(self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend])
+ sw.plot_amplitudes(self.we_dense, backend=backend, **self.backend_kwargs[backend])
+ unit_ids = self.we_dense.unit_ids[:4]
+ sw.plot_amplitudes(self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend])
sw.plot_amplitudes(
- self.we, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend]
+ self.we_dense,
+ unit_ids=unit_ids,
+ plot_histograms=True,
+ backend=backend,
+ **self.backend_kwargs[backend],
)
sw.plot_amplitudes(
self.we_sparse,
@@ -252,9 +257,9 @@ def test_plot_all_amplitudes_distributions(self):
possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- unit_ids = self.we.unit_ids[:4]
+ unit_ids = self.we_dense.unit_ids[:4]
sw.plot_all_amplitudes_distributions(
- self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]
+ self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]
)
sw.plot_all_amplitudes_distributions(
self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]
@@ -264,7 +269,9 @@ def test_unit_locations(self):
possible_backends = list(sw.UnitLocationsWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- sw.plot_unit_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend])
+ sw.plot_unit_locations(
+ self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]
+ )
sw.plot_unit_locations(
self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]
)
@@ -273,7 +280,9 @@ def test_spike_locations(self):
possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- sw.plot_spike_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend])
+ sw.plot_spike_locations(
+ self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]
+ )
sw.plot_spike_locations(
self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]
)
@@ -282,46 +291,46 @@ def test_similarity(self):
possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- sw.plot_template_similarity(self.we, backend=backend, **self.backend_kwargs[backend])
+ sw.plot_template_similarity(self.we_dense, backend=backend, **self.backend_kwargs[backend])
sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend])

def test_quality_metrics(self):
possible_backends = list(sw.QualityMetricsWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- sw.plot_quality_metrics(self.we, backend=backend, **self.backend_kwargs[backend])
+ sw.plot_quality_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend])
sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend])

def test_template_metrics(self):
possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- sw.plot_template_metrics(self.we, backend=backend, **self.backend_kwargs[backend])
+ sw.plot_template_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend])
sw.plot_template_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend])

def test_plot_unit_depths(self):
possible_backends = list(sw.UnitDepthsWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- sw.plot_unit_depths(self.we, backend=backend, **self.backend_kwargs[backend])
+ sw.plot_unit_depths(self.we_dense, backend=backend, **self.backend_kwargs[backend])
sw.plot_unit_depths(self.we_sparse, backend=backend, **self.backend_kwargs[backend])

def test_plot_unit_summary(self):
possible_backends = list(sw.UnitSummaryWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
sw.plot_unit_summary(
- self.we, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend]
+ self.we_dense, self.we_dense.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend]
)
sw.plot_unit_summary(
- self.we_sparse, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend]
+ self.we_sparse, self.we_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend]
)

def test_sorting_summary(self):
possible_backends = list(sw.SortingSummaryWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
- sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend])
+ sw.plot_sorting_summary(self.we_dense, backend=backend, **self.backend_kwargs[backend])
sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend])

def test_plot_agreement_matrix(self):
@@ -369,10 +378,10 @@ def test_plot_rasters(self):
# mytest.test_quality_metrics()
# mytest.test_template_metrics()
# mytest.test_amplitudes()
- # mytest.test_plot_agreement_matrix()
+ mytest.test_plot_agreement_matrix()
# mytest.test_plot_confusion_matrix()
# mytest.test_plot_probe_map()
- mytest.test_plot_rasters()
+ # mytest.test_plot_rasters()

# plt.ion()
plt.show()
