From ce9a72c33823eaa51254f77607cd2c0c15691a53 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 28 Sep 2023 21:05:42 +0200 Subject: [PATCH 1/4] Move UnitProbeMapWidget to new widgets API --- .../widgets/_legacy_mpl_widgets/__init__.py | 3 - .../tests/test_widgets_legacy.py | 6 +- .../widgets/tests/test_widgets.py | 10 ++- .../unitprobemap.py => unit_probe_map.py} | 81 +++++++++---------- src/spikeinterface/widgets/widget_list.py | 3 + 5 files changed, 54 insertions(+), 49 deletions(-) rename src/spikeinterface/widgets/{_legacy_mpl_widgets/unitprobemap.py => unit_probe_map.py} (65%) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index c10c78cbfc..ff144a9943 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -7,9 +7,6 @@ # waveform/PC related from .principalcomponent import plot_principal_component -# units on probe -from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget - from .multicompgraph import ( plot_multicomp_graph, MultiCompGraphWidget, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 39eb80e2e5..9aeb08698e 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -43,9 +43,9 @@ def setUp(self): def tearDown(self): pass - def test_plot_unit_probe_map(self): - sw.plot_unit_probe_map(self._we, with_channel_ids=True) - sw.plot_unit_probe_map(self._we, animated=True) + # def test_plot_unit_probe_map(self): + # sw.plot_unit_probe_map(self._we, with_channel_ids=True) + # sw.plot_unit_probe_map(self._we, animated=True) # def test_plot_units_depth_vs_amplitude(self): # sw.plot_units_depth_vs_amplitude(self._we) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index f44878927d..f1c3456305 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -348,6 +348,13 @@ def test_plot_rasters(self): if backend not in self.skip_backends: sw.plot_rasters(self.sorting) + def test_plot_unit_probe_map(self): + possible_backends = list(sw.UnitProbeMapWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_unit_probe_map(self.we) + + if __name__ == "__main__": # unittest.main() @@ -372,7 +379,8 @@ def test_plot_rasters(self): # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() - mytest.test_plot_rasters() + # mytest.test_plot_rasters() + mytest.test_plot_unit_probe_map() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitprobemap.py b/src/spikeinterface/widgets/unit_probe_map.py similarity index 65% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/unitprobemap.py rename to src/spikeinterface/widgets/unit_probe_map.py index 6522c736ea..66b7ff3126 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitprobemap.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -1,6 +1,11 @@ import numpy as np +from typing import Union -from .basewidget import BaseWidget +# from probeinterface import ProbeGroup + +from .base import BaseWidget, to_attr +# from .utils import get_unit_colors +from ..core.waveform_extractor import WaveformExtractor class UnitProbeMapWidget(BaseWidget): @@ -21,7 +26,6 @@ class UnitProbeMapWidget(BaseWidget): with_channel_ids: bool False default add channel ids text on the probe """ - def __init__( self, waveform_extractor, @@ -30,14 +34,10 @@ def __init__( animated=None, with_channel_ids=False, colorbar=True, - ncols=5, - axes=None, + backend=None, + **backend_kwargs, ): - from matplotlib.animation import FuncAnimation - from matplotlib import pyplot as plt - from probeinterface.plotting import plot_probe - self.waveform_extractor = waveform_extractor if unit_ids is None: unit_ids = waveform_extractor.sorting.unit_ids self.unit_ids = unit_ids @@ -45,44 +45,50 @@ def __init__( channel_ids = waveform_extractor.recording.channel_ids self.channel_ids = channel_ids - self.animated = animated - self.with_channel_ids = with_channel_ids - self.colorbar = colorbar - probes = waveform_extractor.recording.get_probes() - assert len(probes) == 1, ( - "Unit probe map is only available for a single probe. If you have a probe group, " - "consider splitting the recording from different probes" + data_plot = dict( + waveform_extractor=waveform_extractor, + unit_ids=unit_ids, + channel_ids=channel_ids, + animated=animated, + with_channel_ids=with_channel_ids, + colorbar=colorbar, ) - # layout - n = len(unit_ids) - if n < ncols: - ncols = n - nrows = int(np.ceil(n / ncols)) - if axes is None: - fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True) - BaseWidget.__init__(self, None, None, axes) - - def plot(self): - we = self.waveform_extractor + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from probeinterface.plotting import plot_probe + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + if backend_kwargs.get("axes", None) is None: + backend_kwargs["num_axes"] = len(dp.unit_ids) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + + we = dp.waveform_extractor probe = we.get_probe() probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) all_poly_contact = [] - for i, unit_id in enumerate(self.unit_ids): + for i, unit_id in enumerate(dp.unit_ids): ax = self.axes.flatten()[i] template = we.get_template(unit_id) # static - if self.animated: + if dp.animated: contacts_values = np.zeros(template.shape[1]) else: contacts_values = np.max(np.abs(template), axis=0) text_on_contact = None - if self.with_channel_ids: - text_on_contact = self.channel_ids - from probeinterface.plotting import plot_probe + if dp.with_channel_ids: + text_on_contact = dp.channel_ids poly_contact, poly_contour = plot_probe( probe, @@ -96,7 +102,7 @@ def plot(self): if poly_contour is not None: poly_contour.set_zorder(1) - if self.colorbar: + if dp.colorbar: self.figure.colorbar(poly_contact, ax=ax) poly_contact.set_clim(0, np.max(np.abs(template))) @@ -104,7 +110,7 @@ def plot(self): ax.set_title(str(unit_id)) - if self.animated: + if dp.animated: num_frames = template.shape[0] def animate_func(frame): @@ -118,12 +124,3 @@ def animate_func(frame): from matplotlib.animation import FuncAnimation self.animation = FuncAnimation(self.figure, animate_func, frames=num_frames, interval=20, blit=True) - - -def plot_unit_probe_map(*args, **kwargs): - W = UnitProbeMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_probe_map.__doc__ = UnitProbeMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index ed77de6128..525227a2e1 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -20,6 +20,7 @@ from .traces import TracesWidget from .unit_depths import UnitDepthsWidget from .unit_locations import UnitLocationsWidget +from .unit_probe_map import UnitProbeMapWidget from .unit_summary import UnitSummaryWidget from .unit_templates import UnitTemplatesWidget from .unit_waveforms_density_map import UnitWaveformDensityMapWidget @@ -46,6 +47,7 @@ TracesWidget, UnitDepthsWidget, UnitLocationsWidget, + UnitProbeMapWidget, UnitSummaryWidget, UnitTemplatesWidget, UnitWaveformDensityMapWidget, @@ -107,6 +109,7 @@ plot_traces = TracesWidget plot_unit_depths = UnitDepthsWidget plot_unit_locations = UnitLocationsWidget +plot_unit_probe_map = UnitProbeMapWidget plot_unit_summary = UnitSummaryWidget plot_unit_templates = UnitTemplatesWidget plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget From e4144c589ab42ccab18e67931d39919a220e85b1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 14:13:29 +0200 Subject: [PATCH 2/4] Move isi distribution to new widget API. --- .../widgets/_legacy_mpl_widgets/__init__.py | 4 - .../tests/test_widgets_legacy.py | 11 +-- .../widgets/isi_distribution.py | 75 +++++++++++++++++++ src/spikeinterface/widgets/widget_list.py | 3 + 4 files changed, 82 insertions(+), 11 deletions(-) create mode 100644 src/spikeinterface/widgets/isi_distribution.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index ff144a9943..061fc55339 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,11 +1,7 @@ -# isi/ccg/acg -from .isidistribution import plot_isi_distribution, ISIDistributionWidget # peak activity from .activity import plot_peak_activity_map, PeakActivityMapWidget -# waveform/PC related -from .principalcomponent import plot_principal_component from .multicompgraph import ( plot_multicomp_graph, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 4e1bf445fc..9cd321db3c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -58,9 +58,6 @@ def tearDown(self): # def test_amplitudes_distribution(self): # sw.plot_amplitudes_distribution(self._we) - def test_principal_component(self): - sw.plot_principal_component(self._we) - # def test_plot_unit_localization(self): # sw.plot_unit_localization(self._we, with_channel_ids=True) # sw.plot_unit_localization(self._we, method='monopolar_triangulation') @@ -73,10 +70,10 @@ def test_principal_component(self): # unit_ids = self._sorting.unit_ids[:4] # sw.plot_crosscorrelograms(self._sorting, unit_ids=unit_ids, window_ms=500.0, bin_ms=20.0) - def test_isi_distribution(self): - sw.plot_isi_distribution(self._sorting, bin_ms=5.0, window_ms=500.0) - fig, axes = plt.subplots(self.num_units, 1) - sw.plot_isi_distribution(self._sorting, axes=axes) + # def test_isi_distribution(self): + # sw.plot_isi_distribution(self._sorting, bin_ms=5.0, window_ms=500.0) + # fig, axes = plt.subplots(self.num_units, 1) + # sw.plot_isi_distribution(self._sorting, axes=axes) def test_plot_peak_activity_map(self): sw.plot_peak_activity_map(self._rec, with_channel_ids=True) diff --git a/src/spikeinterface/widgets/isi_distribution.py b/src/spikeinterface/widgets/isi_distribution.py new file mode 100644 index 0000000000..2d92d1daf7 --- /dev/null +++ b/src/spikeinterface/widgets/isi_distribution.py @@ -0,0 +1,75 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + + + +class ISIDistributionWidget(BaseWidget): + """ + Plots spike train ISI distribution. + + Parameters + ---------- + sorting: SortingExtractor + The sorting extractor object + unit_ids: list + List of unit ids + bins_ms: int + Bin size in ms + window_ms: float + Window size in ms + + """ + + def __init__( + self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs + ): + + if unit_ids is None: + unit_ids = sorting.get_unit_ids() + + plot_data = dict( + sorting=sorting, + unit_ids=unit_ids, + window_ms=window_ms, + bin_ms=bin_ms, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + if backend_kwargs.get("axes", None) is None: + backend_kwargs["num_axes"] = len(dp.unit_ids) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + sorting = dp.sorting + num_segments = sorting.get_num_segments() + fs = sorting.sampling_frequency + + for i, unit_id in enumerate(dp.unit_ids): + ax = self.axes.flatten()[i] + + bins = np.arange(0, dp.window_ms, dp.bin_ms) + bin_counts = None + for segment_index in range(num_segments): + times_ms = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) / fs * 1000. + isi = np.diff(times_ms) + + bin_counts_, bin_edges = np.histogram(isi, bins=bins, density=True) + if segment_index == 0: + bin_counts = bin_counts_ + else: + bin_counts += bin_counts_ + # TODO handle sensity when several segments + + ax.bar(x=bin_edges[:-1], height=bin_counts, width=dp.bin_ms, color="gray", align="edge") + + ax.set_ylabel(f"{unit_id}") diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 525227a2e1..cec4b5ce53 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -8,6 +8,7 @@ from .autocorrelograms import AutoCorrelogramsWidget from .confusion_matrix import ConfusionMatrixWidget from .crosscorrelograms import CrossCorrelogramsWidget +from .isi_distribution import ISIDistributionWidget from .motion import MotionWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget @@ -35,6 +36,7 @@ AutoCorrelogramsWidget, ConfusionMatrixWidget, CrossCorrelogramsWidget, + ISIDistributionWidget, MotionWidget, ProbeMapWidget, QualityMetricsWidget, @@ -97,6 +99,7 @@ plot_autocorrelograms = AutoCorrelogramsWidget plot_confusion_matrix = ConfusionMatrixWidget plot_crosscorrelograms = CrossCorrelogramsWidget +plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget From 86d073930c385efd9df991883dafde3a4897d2d9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 16 Oct 2023 17:32:01 +0200 Subject: [PATCH 3/4] oups --- src/spikeinterface/widgets/tests/test_widgets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 92ef4aa6c3..4443ef7b03 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -376,7 +376,7 @@ def test_plot_unit_probe_map(self): possible_backends = list(sw.UnitProbeMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_probe_map(self.we) + sw.plot_unit_probe_map(self.we_dense) From 4da65edd3ed0dbdc72cdb2e45c65163f3d59db55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 15:33:17 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../widgets/_legacy_mpl_widgets/__init__.py | 1 - src/spikeinterface/widgets/isi_distribution.py | 12 ++++-------- src/spikeinterface/widgets/tests/test_widgets.py | 2 -- src/spikeinterface/widgets/unit_probe_map.py | 5 ++--- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 061fc55339..53c2a5c79e 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,4 +1,3 @@ - # peak activity from .activity import plot_peak_activity_map, PeakActivityMapWidget diff --git a/src/spikeinterface/widgets/isi_distribution.py b/src/spikeinterface/widgets/isi_distribution.py index 2d92d1daf7..4256efd403 100644 --- a/src/spikeinterface/widgets/isi_distribution.py +++ b/src/spikeinterface/widgets/isi_distribution.py @@ -5,7 +5,6 @@ from .utils import get_unit_colors - class ISIDistributionWidget(BaseWidget): """ Plots spike train ISI distribution. @@ -20,13 +19,10 @@ class ISIDistributionWidget(BaseWidget): Bin size in ms window_ms: float Window size in ms - - """ - def __init__( - self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs - ): + """ + def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs): if unit_ids is None: unit_ids = sorting.get_unit_ids() @@ -53,14 +49,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting = dp.sorting num_segments = sorting.get_num_segments() fs = sorting.sampling_frequency - + for i, unit_id in enumerate(dp.unit_ids): ax = self.axes.flatten()[i] bins = np.arange(0, dp.window_ms, dp.bin_ms) bin_counts = None for segment_index in range(num_segments): - times_ms = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) / fs * 1000. + times_ms = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) / fs * 1000.0 isi = np.diff(times_ms) bin_counts_, bin_edges = np.histogram(isi, bins=bins, density=True) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 4443ef7b03..bc3ab4272a 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -245,7 +245,6 @@ def test_isi_distribution(self): **self.backend_kwargs[backend], ) - def test_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: @@ -377,7 +376,6 @@ def test_plot_unit_probe_map(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_probe_map(self.we_dense) - if __name__ == "__main__": diff --git a/src/spikeinterface/widgets/unit_probe_map.py b/src/spikeinterface/widgets/unit_probe_map.py index 66b7ff3126..4068c1c530 100644 --- a/src/spikeinterface/widgets/unit_probe_map.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -4,6 +4,7 @@ # from probeinterface import ProbeGroup from .base import BaseWidget, to_attr + # from .utils import get_unit_colors from ..core.waveform_extractor import WaveformExtractor @@ -26,6 +27,7 @@ class UnitProbeMapWidget(BaseWidget): with_channel_ids: bool False default add channel ids text on the probe """ + def __init__( self, waveform_extractor, @@ -37,7 +39,6 @@ def __init__( backend=None, **backend_kwargs, ): - if unit_ids is None: unit_ids = waveform_extractor.sorting.unit_ids self.unit_ids = unit_ids @@ -45,7 +46,6 @@ def __init__( channel_ids = waveform_extractor.recording.channel_ids self.channel_ids = channel_ids - data_plot = dict( waveform_extractor=waveform_extractor, unit_ids=unit_ids, @@ -71,7 +71,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - we = dp.waveform_extractor probe = we.get_probe()