Skip to content

Commit

Permalink
Merge pull request #2068 from samuelgarcia/widgets
Browse files Browse the repository at this point in the history
Move more widgets to new widgets API
  • Loading branch information
samuelgarcia authored Oct 18, 2023
2 parents cb219c7 + 4da65ed commit 64046de
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 216 deletions.
8 changes: 0 additions & 8 deletions src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
# isi/ccg/acg
from .isidistribution import plot_isi_distribution, ISIDistributionWidget

# peak activity
from .activity import plot_peak_activity_map, PeakActivityMapWidget

# waveform/PC related
from .principalcomponent import plot_principal_component

# units on probe
from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget

from .multicompgraph import (
plot_multicomp_graph,
Expand Down
112 changes: 0 additions & 112 deletions src/spikeinterface/widgets/_legacy_mpl_widgets/isidistribution.py

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def setUp(self):
def tearDown(self):
pass

def test_plot_unit_probe_map(self):
sw.plot_unit_probe_map(self._we, with_channel_ids=True)
sw.plot_unit_probe_map(self._we, animated=True)
# def test_plot_unit_probe_map(self):
# sw.plot_unit_probe_map(self._we, with_channel_ids=True)
# sw.plot_unit_probe_map(self._we, animated=True)

# def test_plot_units_depth_vs_amplitude(self):
# sw.plot_units_depth_vs_amplitude(self._we)
Expand All @@ -58,9 +58,6 @@ def test_plot_unit_probe_map(self):
# def test_amplitudes_distribution(self):
# sw.plot_amplitudes_distribution(self._we)

def test_principal_component(self):
sw.plot_principal_component(self._we)

# def test_plot_unit_localization(self):
# sw.plot_unit_localization(self._we, with_channel_ids=True)
# sw.plot_unit_localization(self._we, method='monopolar_triangulation')
Expand All @@ -73,10 +70,10 @@ def test_principal_component(self):
# unit_ids = self._sorting.unit_ids[:4]
# sw.plot_crosscorrelograms(self._sorting, unit_ids=unit_ids, window_ms=500.0, bin_ms=20.0)

def test_isi_distribution(self):
sw.plot_isi_distribution(self._sorting, bin_ms=5.0, window_ms=500.0)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_isi_distribution(self._sorting, axes=axes)
# def test_isi_distribution(self):
# sw.plot_isi_distribution(self._sorting, bin_ms=5.0, window_ms=500.0)
# fig, axes = plt.subplots(self.num_units, 1)
# sw.plot_isi_distribution(self._sorting, axes=axes)

def test_plot_peak_activity_map(self):
sw.plot_peak_activity_map(self._rec, with_channel_ids=True)
Expand Down
71 changes: 71 additions & 0 deletions src/spikeinterface/widgets/isi_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
from warnings import warn

from .base import BaseWidget, to_attr
from .utils import get_unit_colors


class ISIDistributionWidget(BaseWidget):
"""
Plots spike train ISI distribution.
Parameters
----------
sorting: SortingExtractor
The sorting extractor object
unit_ids: list
List of unit ids
bins_ms: int
Bin size in ms
window_ms: float
Window size in ms
"""

def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs):
if unit_ids is None:
unit_ids = sorting.get_unit_ids()

plot_data = dict(
sorting=sorting,
unit_ids=unit_ids,
window_ms=window_ms,
bin_ms=bin_ms,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
from .utils_matplotlib import make_mpl_figure

dp = to_attr(data_plot)

if backend_kwargs.get("axes", None) is None:
backend_kwargs["num_axes"] = len(dp.unit_ids)

self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

sorting = dp.sorting
num_segments = sorting.get_num_segments()
fs = sorting.sampling_frequency

for i, unit_id in enumerate(dp.unit_ids):
ax = self.axes.flatten()[i]

bins = np.arange(0, dp.window_ms, dp.bin_ms)
bin_counts = None
for segment_index in range(num_segments):
times_ms = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) / fs * 1000.0
isi = np.diff(times_ms)

bin_counts_, bin_edges = np.histogram(isi, bins=bins, density=True)
if segment_index == 0:
bin_counts = bin_counts_
else:
bin_counts += bin_counts_
# TODO handle sensity when several segments

ax.bar(x=bin_edges[:-1], height=bin_counts, width=dp.bin_ms, color="gray", align="edge")

ax.set_ylabel(f"{unit_id}")
24 changes: 23 additions & 1 deletion src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,20 @@ def test_crosscorrelogram(self):
**self.backend_kwargs[backend],
)

def test_isi_distribution(self):
possible_backends = list(sw.ISIDistributionWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
unit_ids = self.sorting.unit_ids[:4]
sw.plot_isi_distribution(
self.sorting,
unit_ids=unit_ids,
window_ms=25.0,
bin_ms=2.0,
backend=backend,
**self.backend_kwargs[backend],
)

def test_amplitudes(self):
possible_backends = list(sw.AmplitudesWidget.get_possible_backends())
for backend in possible_backends:
Expand Down Expand Up @@ -357,6 +371,12 @@ def test_plot_rasters(self):
if backend not in self.skip_backends:
sw.plot_rasters(self.sorting)

def test_plot_unit_probe_map(self):
possible_backends = list(sw.UnitProbeMapWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
sw.plot_unit_probe_map(self.we_dense)


if __name__ == "__main__":
# unittest.main()
Expand All @@ -374,6 +394,8 @@ def test_plot_rasters(self):
# mytest.test_plot_unit_depths()
# mytest.test_plot_unit_templates()
# mytest.test_plot_unit_summary()
# mytest.test_crosscorrelogram()
mytest.test_isi_distribution()
# mytest.test_unit_locations()
# mytest.test_quality_metrics()
mytest.test_template_metrics()
Expand All @@ -382,6 +404,6 @@ def test_plot_rasters(self):
# mytest.test_plot_confusion_matrix()
# mytest.test_plot_probe_map()
# mytest.test_plot_rasters()
# mytest.test_plot_unit_probe_map()

# plt.ion()
plt.show()
Loading

0 comments on commit 64046de

Please sign in to comment.