From 936746b4cc03eeebd13fd4fcccb55bb42007fabf Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 17 Jul 2023 18:14:22 +0200 Subject: [PATCH 01/31] Initial work to refactor widget in a unique class. --- src/spikeinterface/widgets/base.py | 89 +++++-- src/spikeinterface/widgets/unit_locations.py | 245 ++++++++++++++++++- 2 files changed, 310 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 9a914bf28d..7b62dc3507 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -19,15 +19,63 @@ def set_default_plotter_backend(backend): default_backend_ = backend + +backend_kwargs_desc = { + "matplotlib": { + "figure": "Matplotlib figure. When None, it is created. Default None", + "ax": "Single matplotlib axis. When None, it is created. Default None", + "axes": "Multiple matplotlib axes. When None, they is created. Default None", + "ncols": "Number of columns to create in subplots. Default 5", + "figsize": "Size of matplotlib figure. Default None", + "figtitle": "The figure title. Default None", + }, + 'sortingview': { + "generate_url": "If True, the figurl URL is generated and printed. Default True", + "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", + "figlabel": "The figurl figure label. Default None", + "height": "The height of the sortingview View in jupyter. Default None", + }, + "ipywidgets" : { + "width_cm": "Width of the figure in cm (default 10)", + "height_cm": "Height of the figure in cm (default 6)", + "display": "If True, widgets are immediately displayed", + }, + +} + +default_backend_kwargs = { + "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None}, + "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, + "ipywidgets" : {"width_cm": 25, "height_cm": 10, "display": True}, +} + + + class BaseWidget: # this need to be reset in the subclass possible_backends = None - def __init__(self, plot_data=None, backend=None, **backend_kwargs): + def __init__(self, data_plot=None, backend=None, **backend_kwargs): # every widgets must prepare a dict "plot_data" in the init - self.plot_data = plot_data + self.data_plot = data_plot self.backend = backend - self.backend_kwargs = backend_kwargs + + + for k in backend_kwargs: + if k not in default_backend_kwargs[backend]: + raise Exception( + f"{k} is not a valid plot argument or backend keyword argument. " + f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" + ) + backend_kwargs_ = default_backend_kwargs[backend].copy() + backend_kwargs_.update(backend_kwargs) + + self.backend_kwargs = backend_kwargs_ + + + func = getattr(self, f'plot_{backend}') + func(self) + def check_backend(self, backend): if backend is None: @@ -36,15 +84,16 @@ def check_backend(self, backend): f"{backend} backend not available! Available backends are: " f"{list(self.possible_backends.keys())}" ) return backend + - def check_backend_kwargs(self, plotter, backend, **backend_kwargs): - plotter_kwargs = plotter.default_backend_kwargs - for k in backend_kwargs: - if k not in plotter_kwargs: - raise Exception( - f"{k} is not a valid plot argument or backend keyword argument. " - f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" - ) + # def check_backend_kwargs(self, plotter, backend, **backend_kwargs): + # plotter_kwargs = plotter.default_backend_kwargs + # for k in backend_kwargs: + # if k not in plotter_kwargs: + # raise Exception( + # f"{k} is not a valid plot argument or backend keyword argument. " + # f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" + # ) def do_plot(self, backend, **backend_kwargs): backend = self.check_backend(backend) @@ -74,17 +123,17 @@ def check_extensions(waveform_extractor, extensions): raise Exception(error_msg) -class BackendPlotter: - backend = "" +# class BackendPlotter: +# backend = "" - @classmethod - def register(cls, widget_cls): - widget_cls.register_backend(cls) +# @classmethod +# def register(cls, widget_cls): +# widget_cls.register_backend(cls) - def update_backend_kwargs(self, **backend_kwargs): - backend_kwargs_ = self.default_backend_kwargs.copy() - backend_kwargs_.update(backend_kwargs) - return backend_kwargs_ +# def update_backend_kwargs(self, **backend_kwargs): +# backend_kwargs_ = self.default_backend_kwargs.copy() +# backend_kwargs_.update(backend_kwargs) +# return backend_kwargs_ def copy_signature(source_fct): diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 2c58fdfe45..be9d9cacc8 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -1,7 +1,9 @@ import numpy as np from typing import Union -from .base import BaseWidget +from probeinterface import ProbeGroup + +from .base import BaseWidget, to_attr from .utils import get_unit_colors from ..core.waveform_extractor import WaveformExtractor @@ -31,7 +33,7 @@ class UnitLocationsWidget(BaseWidget): If True, the axis is set to off, default False (matplotlib backend) """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -62,7 +64,7 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids - plot_data = dict( + data_plot = dict( all_unit_ids=sorting.unit_ids, unit_locations=unit_locations, sorting=sorting, @@ -78,4 +80,239 @@ def __init__( hide_axis=hide_axis, ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + from probeinterface.plotting import plot_probe + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + + + + dp = to_attr(self.data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(self.backend_kwargs) + + + unit_locations = dp.unit_locations + + probegroup = ProbeGroup.from_dict(dp.probegroup_dict) + probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) + contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) + + for probe in probegroup.probes: + text_on_contact = None + if dp.with_channel_ids: + text_on_contact = dp.channel_ids + + poly_contact, poly_contour = plot_probe( + probe, + ax=self.ax, + contacts_colors="w", + contacts_kargs=contacts_kargs, + probe_shape_kwargs=probe_shape_kwargs, + text_on_contact=text_on_contact, + ) + poly_contact.set_zorder(2) + if poly_contour is not None: + poly_contour.set_zorder(1) + + self.ax.set_title("") + + # color = np.array([dp.unit_colors[unit_id] for unit_id in dp.unit_ids]) + width = height = 10 + ellipse_kwargs = dict(width=width, height=height, lw=2) + + if dp.plot_all_units: + unit_colors = {} + unit_ids = dp.all_unit_ids + for unit in dp.all_unit_ids: + if unit not in dp.unit_ids: + unit_colors[unit] = "gray" + else: + unit_colors[unit] = dp.unit_colors[unit] + else: + unit_ids = dp.unit_ids + unit_colors = dp.unit_colors + labels = dp.unit_ids + + patches = [ + Ellipse( + (unit_locations[unit]), + color=unit_colors[unit], + zorder=5 if unit in dp.unit_ids else 3, + alpha=0.9 if unit in dp.unit_ids else 0.5, + **ellipse_kwargs, + ) + for i, unit in enumerate(unit_ids) + ] + for p in patches: + self.ax.add_patch(p) + handles = [ + Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) + for unit in dp.unit_ids + ] + + if dp.plot_legend: + if hasattr(self, 'legend') and self.legend is not None: + # if self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + if dp.hide_axis: + self.ax.axis("off") + + def plot_sortingview(self): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + backend_kwargs = self.backend_kwargs + dp = to_attr(self.data_plot) + + # ensure serializable for sortingview + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + + unit_items = [] + for unit_id in unit_ids: + unit_items.append( + vv.UnitLocationsItem( + unit_id=unit_id, x=float(dp.unit_locations[unit_id][0]), y=float(dp.unit_locations[unit_id][1]) + ) + ) + + v_unit_locations = vv.UnitLocations(units=unit_items, channel_locations=locations, disable_auto_rotate=True) + + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_unit_locations)], + ) + else: + self.view = v_unit_locations + + # self.handle_display_and_url(view, **backend_kwargs) + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + cm = 1 / 2.54 + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + fig, ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm + ) + + self.controller = unit_controller + + # mpl_plotter = MplUnitLocationsPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) + # for w in self.controller.values(): + # w.observe(self.updater) + + for w in self.controller.values(): + w.observe(self.update_widget) + + self.widget = widgets.AppLayout( + center=fig.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + ) + + # a first update + self.updater(None) + + if backend_kwargs["display"]: + self.check_backend() + display(self.widget) + + def update_widget(self, change): + self.ax.clear() + + unit_ids = self.controller["unit_ids"].value + + # matplotlib next_data_plot dict update at each call + # data_plot = self.next_data_plot + self.data_plot["unit_ids"] = unit_ids + self.data_plot["plot_all_units"] = True + self.data_plot["plot_legend"] = True + self.data_plot["hide_axis"] = True + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + fig = self.ax.get_figure() + fig.canvas.draw() + fig.canvas.flush_events() + + + + + +class PlotUpdater: + def __init__(self, data_plot, mpl_plotter, ax, controller): + self.data_plot = data_plot + self.mpl_plotter = mpl_plotter + self.ax = ax + self.controller = controller + + self.next_data_plot = data_plot.copy() + + def __call__(self, change): + self.ax.clear() + + unit_ids = self.controller["unit_ids"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["plot_all_units"] = True + data_plot["plot_legend"] = True + data_plot["hide_axis"] = True + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + UnitLocationsWidget.plot_matplotlib(data_plot, **backend_kwargs) + fig = self.ax.get_figure() + fig.canvas.draw() + fig.canvas.flush_events() + + From bccc462a0c89b588df7c48a8f84b18bd00f24dfc Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 18 Jul 2023 11:38:21 +0200 Subject: [PATCH 02/31] refactor wip --- src/spikeinterface/widgets/base.py | 22 ++++-- src/spikeinterface/widgets/unit_locations.py | 52 +++---------- src/spikeinterface/widgets/widget_list.py | 81 ++++++++++---------- 3 files changed, 66 insertions(+), 89 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 7b62dc3507..3b708c57d7 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -55,7 +55,7 @@ class BaseWidget: # this need to be reset in the subclass possible_backends = None - def __init__(self, data_plot=None, backend=None, **backend_kwargs): + def __init__(self, data_plot=None, backend=None, **backend_kwargs, do_plot=True): # every widgets must prepare a dict "plot_data" in the init self.data_plot = data_plot self.backend = backend @@ -72,9 +72,11 @@ def __init__(self, data_plot=None, backend=None, **backend_kwargs): self.backend_kwargs = backend_kwargs_ + if do_plot: + self.do_plot() + + - func = getattr(self, f'plot_{backend}') - func(self) def check_backend(self, backend): @@ -96,11 +98,15 @@ def check_backend(self, backend): # ) def do_plot(self, backend, **backend_kwargs): - backend = self.check_backend(backend) - plotter = self.possible_backends[backend]() - self.check_backend_kwargs(plotter, backend, **backend_kwargs) - plotter.do_plot(self.plot_data, **backend_kwargs) - self.plotter = plotter + # backend = self.check_backend(backend) + # plotter = self.possible_backends[backend]() + # self.check_backend_kwargs(plotter, backend, **backend_kwargs) + # plotter.do_plot(self.plot_data, **backend_kwargs) + # self.plotter = plotter + + func = getattr(self, f'plot_{backend}') + func(self, self.data_plot, self.backend_kwargs) + @classmethod def register_backend(cls, backend_plotter): diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index be9d9cacc8..4ea306bad6 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -82,7 +82,7 @@ def __init__( BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) - def plot_matplotlib(self, **backend_kwargs): + def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from .matplotlib_utils import make_mpl_figure from probeinterface.plotting import plot_probe @@ -93,12 +93,12 @@ def plot_matplotlib(self, **backend_kwargs): - dp = to_attr(self.data_plot) + dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) # self.make_mpl_figure(**backend_kwargs) - self.figure, self.axes, self.ax = make_mpl_figure(self.backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(backend_kwargs) unit_locations = dp.unit_locations @@ -171,13 +171,12 @@ def plot_matplotlib(self, **backend_kwargs): if dp.hide_axis: self.ax.axis("off") - def plot_sortingview(self): + def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs = self.backend_kwargs - dp = to_attr(self.data_plot) + dp = to_attr(data_plot) # ensure serializable for sortingview unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) @@ -215,6 +214,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + self.next_data_plot = data_plot.copy() + cm = 1 / 2.54 # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -228,7 +229,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): with plt.ioff(): output = widgets.Output() with output: - fig, ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() data_plot["unit_ids"] = data_plot["unit_ids"][:1] @@ -265,40 +266,6 @@ def update_widget(self, change): unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call - # data_plot = self.next_data_plot - self.data_plot["unit_ids"] = unit_ids - self.data_plot["plot_all_units"] = True - self.data_plot["plot_legend"] = True - self.data_plot["hide_axis"] = True - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - self.plot_matplotlib(data_plot, **backend_kwargs) - fig = self.ax.get_figure() - fig.canvas.draw() - fig.canvas.flush_events() - - - - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.ax = ax - self.controller = controller - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids @@ -310,9 +277,10 @@ def __call__(self, change): backend_kwargs["ax"] = self.ax # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - UnitLocationsWidget.plot_matplotlib(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() fig.canvas.draw() fig.canvas.flush_events() + diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index a6e0896e99..4dbd4b3c68 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -56,25 +56,25 @@ widget_list = [ - AmplitudesWidget, - AllAmplitudesDistributionsWidget, - AutoCorrelogramsWidget, - CrossCorrelogramsWidget, - QualityMetricsWidget, - SpikeLocationsWidget, - SpikesOnTracesWidget, - TemplateMetricsWidget, - MotionWidget, - TemplateSimilarityWidget, - TimeseriesWidget, + # AmplitudesWidget, + # AllAmplitudesDistributionsWidget, + # AutoCorrelogramsWidget, + # CrossCorrelogramsWidget, + # QualityMetricsWidget, + # SpikeLocationsWidget, + # SpikesOnTracesWidget, + # TemplateMetricsWidget, + # MotionWidget, + # TemplateSimilarityWidget, + # TimeseriesWidget, UnitLocationsWidget, - UnitTemplatesWidget, - UnitWaveformsWidget, - UnitWaveformDensityMapWidget, - UnitDepthsWidget, + # UnitTemplatesWidget, + # UnitWaveformsWidget, + # UnitWaveformDensityMapWidget, + # UnitDepthsWidget, # summary - UnitSummaryWidget, - SortingSummaryWidget, + # UnitSummaryWidget, + # SortingSummaryWidget, ] @@ -101,25 +101,28 @@ # make function for all widgets -plot_amplitudes = define_widget_function_from_class(AmplitudesWidget, "plot_amplitudes") -plot_all_amplitudes_distributions = define_widget_function_from_class( - AllAmplitudesDistributionsWidget, "plot_all_amplitudes_distributions" -) -plot_autocorrelograms = define_widget_function_from_class(AutoCorrelogramsWidget, "plot_autocorrelograms") -plot_crosscorrelograms = define_widget_function_from_class(CrossCorrelogramsWidget, "plot_crosscorrelograms") -plot_quality_metrics = define_widget_function_from_class(QualityMetricsWidget, "plot_quality_metrics") -plot_spike_locations = define_widget_function_from_class(SpikeLocationsWidget, "plot_spike_locations") -plot_spikes_on_traces = define_widget_function_from_class(SpikesOnTracesWidget, "plot_spikes_on_traces") -plot_template_metrics = define_widget_function_from_class(TemplateMetricsWidget, "plot_template_metrics") -plot_motion = define_widget_function_from_class(MotionWidget, "plot_motion") -plot_template_similarity = define_widget_function_from_class(TemplateSimilarityWidget, "plot_template_similarity") -plot_timeseries = define_widget_function_from_class(TimeseriesWidget, "plot_timeseries") -plot_unit_locations = define_widget_function_from_class(UnitLocationsWidget, "plot_unit_locations") -plot_unit_templates = define_widget_function_from_class(UnitTemplatesWidget, "plot_unit_templates") -plot_unit_waveforms = define_widget_function_from_class(UnitWaveformsWidget, "plot_unit_waveforms") -plot_unit_waveforms_density_map = define_widget_function_from_class( - UnitWaveformDensityMapWidget, "plot_unit_waveforms_density_map" -) -plot_unit_depths = define_widget_function_from_class(UnitDepthsWidget, "plot_unit_depths") -plot_unit_summary = define_widget_function_from_class(UnitSummaryWidget, "plot_unit_summary") -plot_sorting_summary = define_widget_function_from_class(SortingSummaryWidget, "plot_sorting_summary") +# plot_amplitudes = define_widget_function_from_class(AmplitudesWidget, "plot_amplitudes") +# plot_all_amplitudes_distributions = define_widget_function_from_class( +# AllAmplitudesDistributionsWidget, "plot_all_amplitudes_distributions" +# ) +# plot_autocorrelograms = define_widget_function_from_class(AutoCorrelogramsWidget, "plot_autocorrelograms") +# plot_crosscorrelograms = define_widget_function_from_class(CrossCorrelogramsWidget, "plot_crosscorrelograms") +# plot_quality_metrics = define_widget_function_from_class(QualityMetricsWidget, "plot_quality_metrics") +# plot_spike_locations = define_widget_function_from_class(SpikeLocationsWidget, "plot_spike_locations") +# plot_spikes_on_traces = define_widget_function_from_class(SpikesOnTracesWidget, "plot_spikes_on_traces") +# plot_template_metrics = define_widget_function_from_class(TemplateMetricsWidget, "plot_template_metrics") +# plot_motion = define_widget_function_from_class(MotionWidget, "plot_motion") +# plot_template_similarity = define_widget_function_from_class(TemplateSimilarityWidget, "plot_template_similarity") +# plot_timeseries = define_widget_function_from_class(TimeseriesWidget, "plot_timeseries") +# plot_unit_locations = define_widget_function_from_class(UnitLocationsWidget, "plot_unit_locations") +# plot_unit_templates = define_widget_function_from_class(UnitTemplatesWidget, "plot_unit_templates") +# plot_unit_waveforms = define_widget_function_from_class(UnitWaveformsWidget, "plot_unit_waveforms") +# plot_unit_waveforms_density_map = define_widget_function_from_class( +# UnitWaveformDensityMapWidget, "plot_unit_waveforms_density_map" +# ) +# plot_unit_depths = define_widget_function_from_class(UnitDepthsWidget, "plot_unit_depths") +# plot_unit_summary = define_widget_function_from_class(UnitSummaryWidget, "plot_unit_summary") +# plot_sorting_summary = define_widget_function_from_class(SortingSummaryWidget, "plot_sorting_summary") + + +plot_unit_locations = UnitLocationsWidget From 750ad2495c7d049d7da1c4e065743c43a467ddc8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 18 Jul 2023 15:02:45 +0200 Subject: [PATCH 03/31] widget wip --- src/spikeinterface/widgets/__init__.py | 44 ++++++------- src/spikeinterface/widgets/base.py | 61 +++++++++---------- .../widgets/tests/test_widgets.py | 49 ++++++++------- src/spikeinterface/widgets/unit_locations.py | 3 +- src/spikeinterface/widgets/widget_list.py | 50 ++++++++------- 5 files changed, 106 insertions(+), 101 deletions(-) diff --git a/src/spikeinterface/widgets/__init__.py b/src/spikeinterface/widgets/__init__.py index 83f4b85fee..bb779ff7fb 100644 --- a/src/spikeinterface/widgets/__init__.py +++ b/src/spikeinterface/widgets/__init__.py @@ -1,35 +1,35 @@ # check if backend are available -try: - import matplotlib +# try: +# import matplotlib - HAVE_MPL = True -except: - HAVE_MPL = False +# HAVE_MPL = True +# except: +# HAVE_MPL = False -try: - import sortingview +# try: +# import sortingview - HAVE_SV = True -except: - HAVE_SV = False +# HAVE_SV = True +# except: +# HAVE_SV = False -try: - import ipywidgets +# try: +# import ipywidgets - HAVE_IPYW = True -except: - HAVE_IPYW = False +# HAVE_IPYW = True +# except: +# HAVE_IPYW = False -# theses import make the Widget.resgister() at import time -if HAVE_MPL: - import spikeinterface.widgets.matplotlib +# # theses import make the Widget.resgister() at import time +# if HAVE_MPL: +# import spikeinterface.widgets.matplotlib -if HAVE_SV: - import spikeinterface.widgets.sortingview +# if HAVE_SV: +# import spikeinterface.widgets.sortingview -if HAVE_IPYW: - import spikeinterface.widgets.ipywidgets +# if HAVE_IPYW: +# import spikeinterface.widgets.ipywidgets # when importing widget list backend are already registered from .widget_list import * diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 3b708c57d7..17903b495b 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -55,12 +55,12 @@ class BaseWidget: # this need to be reset in the subclass possible_backends = None - def __init__(self, data_plot=None, backend=None, **backend_kwargs, do_plot=True): + def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_kwargs, ): # every widgets must prepare a dict "plot_data" in the init self.data_plot = data_plot - self.backend = backend - + self.backend = self.check_backend(backend) + # check backend kwargs for k in backend_kwargs: if k not in default_backend_kwargs[backend]: raise Exception( @@ -72,18 +72,18 @@ def __init__(self, data_plot=None, backend=None, **backend_kwargs, do_plot=True) self.backend_kwargs = backend_kwargs_ - if do_plot: - self.do_plot() - + if immediate_plot: + self.do_plot(self.backend, **self.backend_kwargs) - - + @classmethod + def get_possible_backends(cls): + return [ k for k in default_backend_kwargs if hasattr(cls, f"plot_{k}") ] def check_backend(self, backend): if backend is None: backend = get_default_plotter_backend() - assert backend in self.possible_backends, ( - f"{backend} backend not available! Available backends are: " f"{list(self.possible_backends.keys())}" + assert backend in self.get_possible_backends(), ( + f"{backend} backend not available! Available backends are: " f"{self.get_possible_backends()}" ) return backend @@ -99,18 +99,13 @@ def check_backend(self, backend): def do_plot(self, backend, **backend_kwargs): # backend = self.check_backend(backend) - # plotter = self.possible_backends[backend]() - # self.check_backend_kwargs(plotter, backend, **backend_kwargs) - # plotter.do_plot(self.plot_data, **backend_kwargs) - # self.plotter = plotter func = getattr(self, f'plot_{backend}') - func(self, self.data_plot, self.backend_kwargs) - + func(data_plot=self.data_plot, **self.backend_kwargs) - @classmethod - def register_backend(cls, backend_plotter): - cls.possible_backends[backend_plotter.backend] = backend_plotter + # @classmethod + # def register_backend(cls, backend_plotter): + # cls.possible_backends[backend_plotter.backend] = backend_plotter @staticmethod def check_extensions(waveform_extractor, extensions): @@ -142,12 +137,12 @@ def check_extensions(waveform_extractor, extensions): # return backend_kwargs_ -def copy_signature(source_fct): - def copy(target_fct): - target_fct.__signature__ = inspect.signature(source_fct) - return target_fct +# def copy_signature(source_fct): +# def copy(target_fct): +# target_fct.__signature__ = inspect.signature(source_fct) +# return target_fct - return copy +# return copy class to_attr(object): @@ -168,14 +163,14 @@ def __getattribute__(self, k): return d[k] -def define_widget_function_from_class(widget_class, name): - @copy_signature(widget_class) - def widget_func(*args, **kwargs): - W = widget_class(*args, **kwargs) - W.do_plot(W.backend, **W.backend_kwargs) - return W.plotter +# def define_widget_function_from_class(widget_class, name): +# @copy_signature(widget_class) +# def widget_func(*args, **kwargs): +# W = widget_class(*args, **kwargs) +# W.do_plot(W.backend, **W.backend_kwargs) +# return W.plotter - widget_func.__doc__ = widget_class.__doc__ - widget_func.__name__ = name +# widget_func.__doc__ = widget_class.__doc__ +# widget_func.__name__ = name - return widget_func +# return widget_func diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 3a60a9d2c7..cb4341f044 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -13,7 +13,8 @@ from spikeinterface import extract_waveforms, load_waveforms, download_dataset, compute_sparsity -from spikeinterface.widgets import HAVE_MPL, HAVE_SV +# from spikeinterface.widgets import HAVE_MPL, HAVE_SV + import spikeinterface.extractors as se import spikeinterface.widgets as sw @@ -68,7 +69,10 @@ def setUpClass(cls): # make sparse waveforms cls.sparsity_radius = compute_sparsity(cls.we, method="radius", radius_um=50) cls.sparsity_best = compute_sparsity(cls.we, method="best_channels", num_channels=5) - cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) + if (cache_folder / "mearec_test_sparse").is_dir(): + cls.we_sparse = load_waveforms(cache_folder / "mearec_test_sparse") + else: + cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) cls.skip_backends = ["ipywidgets"] @@ -82,7 +86,7 @@ def setUpClass(cls): cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) def test_plot_timeseries(self): - possible_backends = list(sw.TimeseriesWidget.possible_backends.keys()) + possible_backends = list(sw.TimeseriesWidget.get_possible_backends()) for backend in possible_backends: if ON_GITHUB and backend == "sortingview": continue @@ -119,7 +123,7 @@ def test_plot_timeseries(self): ) def test_plot_unit_waveforms(self): - possible_backends = list(sw.UnitWaveformsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_waveforms(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -143,7 +147,7 @@ def test_plot_unit_waveforms(self): ) def test_plot_unit_templates(self): - possible_backends = list(sw.UnitWaveformsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_templates(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -164,7 +168,7 @@ def test_plot_unit_templates(self): ) def test_plot_unit_waveforms_density_map(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] @@ -173,7 +177,7 @@ def test_plot_unit_waveforms_density_map(self): ) def test_plot_unit_waveforms_density_map_sparsity_radius(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] @@ -187,7 +191,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self): ) def test_plot_unit_waveforms_density_map_sparsity_None_same_axis(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] @@ -201,7 +205,7 @@ def test_plot_unit_waveforms_density_map_sparsity_None_same_axis(self): ) def test_autocorrelograms(self): - possible_backends = list(sw.AutoCorrelogramsWidget.possible_backends.keys()) + possible_backends = list(sw.AutoCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:4] @@ -215,7 +219,7 @@ def test_autocorrelograms(self): ) def test_crosscorrelogram(self): - possible_backends = list(sw.CrossCorrelogramsWidget.possible_backends.keys()) + possible_backends = list(sw.CrossCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:4] @@ -229,7 +233,7 @@ def test_crosscorrelogram(self): ) def test_amplitudes(self): - possible_backends = list(sw.AmplitudesWidget.possible_backends.keys()) + possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_amplitudes(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -247,7 +251,7 @@ def test_amplitudes(self): ) def test_plot_all_amplitudes_distributions(self): - possible_backends = list(sw.AllAmplitudesDistributionsWidget.possible_backends.keys()) + possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.we.unit_ids[:4] @@ -259,7 +263,7 @@ def test_plot_all_amplitudes_distributions(self): ) def test_unit_locations(self): - possible_backends = list(sw.UnitLocationsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) @@ -268,7 +272,7 @@ def test_unit_locations(self): ) def test_spike_locations(self): - possible_backends = list(sw.SpikeLocationsWidget.possible_backends.keys()) + possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_spike_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) @@ -277,35 +281,35 @@ def test_spike_locations(self): ) def test_similarity(self): - possible_backends = list(sw.TemplateSimilarityWidget.possible_backends.keys()) + possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_template_similarity(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_quality_metrics(self): - possible_backends = list(sw.QualityMetricsWidget.possible_backends.keys()) + possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_quality_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_template_metrics(self): - possible_backends = list(sw.TemplateMetricsWidget.possible_backends.keys()) + possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_template_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_depths(self): - possible_backends = list(sw.UnitDepthsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitDepthsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_depths(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_unit_depths(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_summary(self): - possible_backends = list(sw.UnitSummaryWidget.possible_backends.keys()) + possible_backends = list(sw.UnitSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( @@ -316,7 +320,7 @@ def test_plot_unit_summary(self): ) def test_sorting_summary(self): - possible_backends = list(sw.SortingSummaryWidget.possible_backends.keys()) + possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -339,8 +343,9 @@ def test_sorting_summary(self): # mytest.test_plot_unit_depths() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_summary() - mytest.test_quality_metrics() - mytest.test_template_metrics() + mytest.test_unit_locations() + # mytest.test_quality_metrics() + # mytest.test_template_metrics() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 4ea306bad6..036158cda7 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -79,10 +79,11 @@ def __init__( plot_legend=plot_legend, hide_axis=hide_axis, ) - + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): + print(data_plot, backend_kwargs) import matplotlib.pyplot as plt from .matplotlib_utils import make_mpl_figure from probeinterface.plotting import plot_probe diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 4dbd4b3c68..53f2e7eb62 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -1,29 +1,30 @@ -from .base import define_widget_function_from_class +# from .base import define_widget_function_from_class +from .base import backend_kwargs_desc # basics -from .timeseries import TimeseriesWidget +# from .timeseries import TimeseriesWidget # waveform -from .unit_waveforms import UnitWaveformsWidget -from .unit_templates import UnitTemplatesWidget -from .unit_waveforms_density_map import UnitWaveformDensityMapWidget +# from .unit_waveforms import UnitWaveformsWidget +# from .unit_templates import UnitTemplatesWidget +# from .unit_waveforms_density_map import UnitWaveformDensityMapWidget # isi/ccg/acg -from .autocorrelograms import AutoCorrelogramsWidget -from .crosscorrelograms import CrossCorrelogramsWidget +# from .autocorrelograms import AutoCorrelogramsWidget +# from .crosscorrelograms import CrossCorrelogramsWidget # peak activity # drift/motion # spikes-traces -from .spikes_on_traces import SpikesOnTracesWidget +# from .spikes_on_traces import SpikesOnTracesWidget # PC related # units on probe from .unit_locations import UnitLocationsWidget -from .spike_locations import SpikeLocationsWidget +# from .spike_locations import SpikeLocationsWidget # unit presence @@ -33,26 +34,26 @@ # correlogram comparison # amplitudes -from .amplitudes import AmplitudesWidget -from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget +# from .amplitudes import AmplitudesWidget +# from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget # metrics -from .quality_metrics import QualityMetricsWidget -from .template_metrics import TemplateMetricsWidget +# from .quality_metrics import QualityMetricsWidget +# from .template_metrics import TemplateMetricsWidget # motion/drift -from .motion import MotionWidget +# from .motion import MotionWidget # similarity -from .template_similarity import TemplateSimilarityWidget +# from .template_similarity import TemplateSimilarityWidget -from .unit_depths import UnitDepthsWidget +# from .unit_depths import UnitDepthsWidget # summary -from .unit_summary import UnitSummaryWidget -from .sorting_summary import SortingSummaryWidget +# from .unit_summary import UnitSummaryWidget +# from .sorting_summary import SortingSummaryWidget widget_list = [ @@ -89,13 +90,16 @@ **backend_kwargs: kwargs {backend_kwargs} """ - backend_str = f" {list(wcls.possible_backends.keys())}" + # backend_str = f" {list(wcls.possible_backends.keys())}" + backend_str = f" {wcls.get_possible_backends()}" backend_kwargs_str = "" - for backend, backend_plotter in wcls.possible_backends.items(): - backend_kwargs_desc = backend_plotter.backend_kwargs_desc - if len(backend_kwargs_desc) > 0: + # for backend, backend_plotter in wcls.possible_backends.items(): + for backend in wcls.get_possible_backends(): + # backend_kwargs_desc = backend_plotter.backend_kwargs_desc + kwargs_desc = backend_kwargs_desc[backend] + if len(kwargs_desc) > 0: backend_kwargs_str += f"\n {backend}:\n\n" - for bk, bk_dsc in backend_kwargs_desc.items(): + for bk, bk_dsc in kwargs_desc.items(): backend_kwargs_str += f" * {bk}: {bk_dsc}\n" wcls.__doc__ = wcls_doc.format(backends=backend_str, backend_kwargs=backend_kwargs_str) From 3f236c703e830c931223bab7fdda5fa0de84cd59 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 09:39:29 +0200 Subject: [PATCH 04/31] widgets utils files --- .../widgets/ipywidgets_utils.py | 105 ++++++++++++++++++ .../widgets/matplotlib_utils.py | 75 +++++++++++++ .../widgets/sortingview_utils.py | 95 ++++++++++++++++ 3 files changed, 275 insertions(+) create mode 100644 src/spikeinterface/widgets/ipywidgets_utils.py create mode 100644 src/spikeinterface/widgets/matplotlib_utils.py create mode 100644 src/spikeinterface/widgets/sortingview_utils.py diff --git a/src/spikeinterface/widgets/ipywidgets_utils.py b/src/spikeinterface/widgets/ipywidgets_utils.py new file mode 100644 index 0000000000..4490cc3365 --- /dev/null +++ b/src/spikeinterface/widgets/ipywidgets_utils.py @@ -0,0 +1,105 @@ +import ipywidgets.widgets as widgets +import numpy as np + + + +def check_ipywidget_backend(): + import matplotlib + mpl_backend = matplotlib.get_backend() + assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" + + + +def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): + time_slider = widgets.FloatSlider( + orientation="horizontal", + description="time:", + value=time_range[0], + min=t_start, + max=t_stop, + continuous_update=False, + layout=widgets.Layout(width=f"{width_cm}cm"), + ) + layer_selector = widgets.Dropdown(description="layer", options=layer_keys) + segment_selector = widgets.Dropdown(description="segment", options=list(range(num_segments))) + window_sizer = widgets.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") + mode_selector = widgets.Dropdown(options=["line", "map"], description="mode", value=mode) + all_layers = widgets.Checkbox(description="plot all layers", value=all_layers) + + controller = { + "layer_key": layer_selector, + "segment_index": segment_selector, + "window": window_sizer, + "t_start": time_slider, + "mode": mode_selector, + "all_layers": all_layers, + } + widget = widgets.VBox( + [time_slider, widgets.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] + ) + + return widget, controller + + +def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): + unit_label = widgets.Label(value="units:") + + unit_selector = widgets.SelectMultiple( + options=all_unit_ids, + value=list(unit_ids), + disabled=False, + layout=widgets.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), + ) + + controller = {"unit_ids": unit_selector} + widget = widgets.VBox([unit_label, unit_selector]) + + return widget, controller + + +def make_channel_controller(recording, width_cm, height_cm): + channel_label = widgets.Label("channel indices:", layout=widgets.Layout(justify_content="center")) + channel_selector = widgets.IntRangeSlider( + value=[0, recording.get_num_channels()], + min=0, + max=recording.get_num_channels(), + step=1, + disabled=False, + continuous_update=False, + orientation="vertical", + readout=True, + readout_format="d", + layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), + ) + + controller = {"channel_inds": channel_selector} + widget = widgets.VBox([channel_label, channel_selector]) + + return widget, controller + + +def make_scale_controller(width_cm, height_cm): + scale_label = widgets.Label("Scale", layout=widgets.Layout(justify_content="center")) + + plus_selector = widgets.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Increase scale", + icon="arrow-up", + layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + ) + + minus_selector = widgets.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Decrease scale", + icon="arrow-down", + layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + ) + + controller = {"plus": plus_selector, "minus": minus_selector} + widget = widgets.VBox([scale_label, plus_selector, minus_selector]) + + return widget, controller diff --git a/src/spikeinterface/widgets/matplotlib_utils.py b/src/spikeinterface/widgets/matplotlib_utils.py new file mode 100644 index 0000000000..6ccaaf5840 --- /dev/null +++ b/src/spikeinterface/widgets/matplotlib_utils.py @@ -0,0 +1,75 @@ +import matplotlib.pyplot as plt +import numpy as np + + +def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, figsize=None, figtitle=None): + """ + figure/ax/axes : only one of then can be not None + """ + if figure is not None: + assert ax is None and axes is None, "figure/ax/axes : only one of then can be not None" + if num_axes is None: + ax = figure.add_subplot(111) + axes = np.array([[ax]]) + else: + assert ncols is not None + axes = [] + nrows = int(np.ceil(num_axes / ncols)) + axes = np.full((nrows, ncols), fill_value=None, dtype=object) + for i in range(num_axes): + ax = figure.add_subplot(nrows, ncols, i + 1) + r = i // ncols + c = i % ncols + axes[r, c] = ax + elif ax is not None: + assert figure is None and axes is None, "figure/ax/axes : only one of then can be not None" + figure = ax.get_figure() + axes = np.array([[ax]]) + elif axes is not None: + assert figure is None and ax is None, "figure/ax/axes : only one of then can be not None" + axes = np.asarray(axes) + figure = axes.flatten()[0].get_figure() + else: + # 'figure/ax/axes are all None + if num_axes is None: + # one fig with one ax + figure, ax = plt.subplots(figsize=figsize) + axes = np.array([[ax]]) + else: + if num_axes == 0: + # one figure without plots (diffred subplot creation with + figure = plt.figure(figsize=figsize) + ax = None + axes = None + elif num_axes == 1: + figure = plt.figure(figsize=figsize) + ax = figure.add_subplot(111) + axes = np.array([[ax]]) + else: + assert ncols is not None + if num_axes < ncols: + ncols = num_axes + nrows = int(np.ceil(num_axes / ncols)) + figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize) + ax = None + # remove extra axes + if ncols * nrows > num_axes: + for i, extra_ax in enumerate(axes.flatten()): + if i >= num_axes: + extra_ax.remove() + r = i // ncols + c = i % ncols + axes[r, c] = None + + if figtitle is not None: + figure.suptitle(figtitle) + + return figure, axes, ax + + # self.figure = figure + # self.ax = ax + # axes is always a 2D array of ax + # self.axes = axes + + # if figtitle is not None: + # self.figure.suptitle(figtitle) \ No newline at end of file diff --git a/src/spikeinterface/widgets/sortingview_utils.py b/src/spikeinterface/widgets/sortingview_utils.py new file mode 100644 index 0000000000..8a4a8f3169 --- /dev/null +++ b/src/spikeinterface/widgets/sortingview_utils.py @@ -0,0 +1,95 @@ +import numpy as np + +from ..core.core_tools import check_json + + + + +sortingview_backend_kwargs_desc = { + "generate_url": "If True, the figurl URL is generated and printed. Default True", + "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", + "figlabel": "The figurl figure label. Default None", + "height": "The height of the sortingview View in jupyter. Default None", +} +sortingview_default_backend_kwargs = {"generate_url": True, "display": True, "figlabel": None, "height": None} + + + +def make_serializable(*args): + dict_to_serialize = {int(i): a for i, a in enumerate(args[1:])} + serializable_dict = check_json(dict_to_serialize) + returns = () + for i in range(len(args) - 1): + returns += (serializable_dict[str(i)],) + if len(returns) == 1: + returns = returns[0] + return returns + +def is_notebook() -> bool: + try: + shell = get_ipython().__class__.__name__ + if shell == "ZMQInteractiveShell": + return True # Jupyter notebook or qtconsole + elif shell == "TerminalInteractiveShell": + return False # Terminal running IPython + else: + return False # Other type (?) + except NameError: + return False + +def handle_display_and_url(widget, view, **backend_kwargs): + url = None + if is_notebook() and backend_kwargs["display"]: + display(view.jupyter(height=backend_kwargs["height"])) + if backend_kwargs["generate_url"]: + figlabel = backend_kwargs.get("figlabel") + if figlabel is None: + figlabel = widget.default_label + url = view.url(label=figlabel) + print(url) + + return url + + + + +def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=None): + import sortingview.views as vv + + if unit_properties is None: + ut_columns = [] + ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] + else: + ut_columns = [] + ut_rows = [] + values = {} + valid_unit_properties = [] + for prop_name in unit_properties: + property_values = sorting.get_property(prop_name) + # make dtype available + val0 = np.array(property_values[0]) + if val0.dtype.kind in ("i", "u"): + dtype = "int" + elif val0.dtype.kind in ("U", "S"): + dtype = "str" + elif val0.dtype.kind == "f": + dtype = "float" + elif val0.dtype.kind == "b": + dtype = "bool" + else: + print(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") + continue + ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) + valid_unit_properties.append(prop_name) + + for ui, unit in enumerate(sorting.unit_ids): + for prop_name in valid_unit_properties: + property_values = sorting.get_property(prop_name) + val0 = property_values[0] + if np.isnan(property_values[ui]): + continue + values[prop_name] = property_values[ui] + ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) + + v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) + return v_units_table From 84170e37b929ee835bb084e93e4b53d5b168178b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 12:37:42 +0200 Subject: [PATCH 05/31] wip refactor widgets --- src/spikeinterface/widgets/base.py | 3 ++- src/spikeinterface/widgets/sortingview_utils.py | 7 ++++--- src/spikeinterface/widgets/tests/test_widgets.py | 1 + src/spikeinterface/widgets/unit_locations.py | 9 +++++---- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 17903b495b..f95004efb9 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -73,6 +73,7 @@ def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_ self.backend_kwargs = backend_kwargs_ if immediate_plot: + print('immediate_plot', self.backend, self.backend_kwargs) self.do_plot(self.backend, **self.backend_kwargs) @classmethod @@ -101,7 +102,7 @@ def do_plot(self, backend, **backend_kwargs): # backend = self.check_backend(backend) func = getattr(self, f'plot_{backend}') - func(data_plot=self.data_plot, **self.backend_kwargs) + func(self.data_plot, **self.backend_kwargs) # @classmethod # def register_backend(cls, backend_plotter): diff --git a/src/spikeinterface/widgets/sortingview_utils.py b/src/spikeinterface/widgets/sortingview_utils.py index 8a4a8f3169..90dfcb77a3 100644 --- a/src/spikeinterface/widgets/sortingview_utils.py +++ b/src/spikeinterface/widgets/sortingview_utils.py @@ -16,10 +16,10 @@ def make_serializable(*args): - dict_to_serialize = {int(i): a for i, a in enumerate(args[1:])} + dict_to_serialize = {int(i): a for i, a in enumerate(args)} serializable_dict = check_json(dict_to_serialize) returns = () - for i in range(len(args) - 1): + for i in range(len(args)): returns += (serializable_dict[str(i)],) if len(returns) == 1: returns = returns[0] @@ -44,7 +44,8 @@ def handle_display_and_url(widget, view, **backend_kwargs): if backend_kwargs["generate_url"]: figlabel = backend_kwargs.get("figlabel") if figlabel is None: - figlabel = widget.default_label + # figlabel = widget.default_label + figlabel = "" url = view.url(label=figlabel) print(url) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index cb4341f044..1dff04d334 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -36,6 +36,7 @@ else: cache_folder = Path("cache_folder") / "widgets" +print(cache_folder) ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 036158cda7..e87f553072 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -83,7 +83,6 @@ def __init__( BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - print(data_plot, backend_kwargs) import matplotlib.pyplot as plt from .matplotlib_utils import make_mpl_figure from probeinterface.plotting import plot_probe @@ -99,7 +98,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # self.make_mpl_figure(**backend_kwargs) - self.figure, self.axes, self.ax = make_mpl_figure(backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) unit_locations = dp.unit_locations @@ -180,6 +179,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) # ensure serializable for sortingview + print(dp.unit_ids, dp.channel_ids) + print(make_serializable(dp.unit_ids, dp.channel_ids)) unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} @@ -256,10 +257,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - self.updater(None) + self.update_widget(None) if backend_kwargs["display"]: - self.check_backend() + # self.check_backend() display(self.widget) def update_widget(self, change): From 5f9e0c9f1e558e1f27aab39aae3c5a955bb144a0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 17:03:17 +0200 Subject: [PATCH 06/31] widget refactor : AllAmplitudesDistributionsWidget and AmplitudesWidget --- .../widgets/all_amplitudes_distributions.py | 41 +++- src/spikeinterface/widgets/amplitudes.py | 183 +++++++++++++++++- src/spikeinterface/widgets/base.py | 7 +- .../widgets/tests/test_widgets.py | 3 +- src/spikeinterface/widgets/unit_locations.py | 78 ++++---- src/spikeinterface/widgets/widget_list.py | 11 +- 6 files changed, 273 insertions(+), 50 deletions(-) diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index d1a0acfe1e..18585a4f96 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -1,7 +1,7 @@ import numpy as np from warnings import warn -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_some_colors from ..core.waveform_extractor import WaveformExtractor @@ -47,3 +47,42 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + ax = self.ax + + unit_amps = [] + for i, unit_id in enumerate(dp.unit_ids): + amps = [] + for segment_index in range(dp.num_segments): + amps.append(dp.amplitudes[segment_index][unit_id]) + amps = np.concatenate(amps) + unit_amps.append(amps) + parts = ax.violinplot(unit_amps, showmeans=False, showmedians=False, showextrema=False) + + for i, pc in enumerate(parts["bodies"]): + color = dp.unit_colors[dp.unit_ids[i]] + pc.set_facecolor(color) + pc.set_edgecolor("black") + pc.set_alpha(1) + + ax.set_xticks(np.arange(len(dp.unit_ids)) + 1) + ax.set_xticklabels([str(unit_id) for unit_id in dp.unit_ids]) + + ylims = ax.get_ylim() + if np.max(ylims) < 0: + ax.set_ylim(min(ylims), 0) + if np.min(ylims) > 0: + ax.set_ylim(0, max(ylims)) \ No newline at end of file diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 833bdf2b06..7c76d26204 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -1,7 +1,7 @@ import numpy as np from warnings import warn -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_some_colors from ..core.waveform_extractor import WaveformExtractor @@ -112,3 +112,184 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + from probeinterface.plotting import plot_probe + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + + + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + if backend_kwargs["axes"] is not None: + axes = backend_kwargs["axes"] + if dp.plot_histograms: + assert np.asarray(axes).size == 2 + else: + assert np.asarray(axes).size == 1 + elif backend_kwargs["ax"] is not None: + assert not dp.plot_histograms + else: + if dp.plot_histograms: + backend_kwargs["num_axes"] = 2 + backend_kwargs["ncols"] = 2 + else: + backend_kwargs["num_axes"] = None + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + scatter_ax = self.axes.flatten()[0] + + for unit_id in dp.unit_ids: + spiketrains = dp.spiketrains[unit_id] + amps = dp.amplitudes[unit_id] + scatter_ax.scatter(spiketrains, amps, color=dp.unit_colors[unit_id], s=3, alpha=1, label=unit_id) + + if dp.plot_histograms: + if dp.bins is None: + bins = int(len(spiketrains) / 30) + else: + bins = dp.bins + ax_hist = self.axes.flatten()[1] + ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + + if dp.plot_histograms: + ax_hist = self.axes.flatten()[1] + ax_hist.set_ylim(scatter_ax.get_ylim()) + ax_hist.axis("off") + self.figure.tight_layout() + + if dp.plot_legend: + # if self.legend is not None: + if hasattr(self, 'legend') and self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + scatter_ax.set_xlim(0, dp.total_duration) + scatter_ax.set_xlabel("Times [s]") + scatter_ax.set_ylabel(f"Amplitude") + scatter_ax.spines["top"].set_visible(False) + scatter_ax.spines["right"].set_visible(False) + self.figure.subplots_adjust(bottom=0.1, top=0.9, left=0.1) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + we = data_plot["waveform_extractor"] + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + # fig = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm + ) + + plot_histograms = widgets.Checkbox( + value=data_plot["plot_histograms"], + description="plot histograms", + disabled=False, + ) + + footer = plot_histograms + + self.controller = {"plot_histograms": plot_histograms} + self.controller.update(unit_controller) + + # mpl_plotter = MplAmplitudesPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) + for w in self.controller.values(): + # w.observe(self.updater) + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + # center=fig.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer + center=self.figure.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer + ) + + # a first update + # self.updater(None) + self._update_ipywidget(None) + + if backend_kwargs["display"]: + # self.check_backend() + display(self.widget) + + def _update_ipywidget(self, change): + # self.fig.clear() + self.figure.clear() + + unit_ids = self.controller["unit_ids"].value + plot_histograms = self.controller["plot_histograms"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["plot_histograms"] = plot_histograms + + backend_kwargs = {} + # backend_kwargs["figure"] = self.fig + backend_kwargs["figure"] = self.figure + backend_kwargs["axes"] = None + backend_kwargs["ax"] = None + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + + # unit_ids = self.make_serializable(dp.unit_ids) + unit_ids = make_serializable(dp.unit_ids) + + sa_items = [ + vv.SpikeAmplitudesItem( + unit_id=u, + spike_times_sec=dp.spiketrains[u].astype("float32"), + spike_amplitudes=dp.amplitudes[u].astype("float32"), + ) + for u in unit_ids + ] + + # v_spike_amplitudes = vv.SpikeAmplitudes( + self.view = vv.SpikeAmplitudes( + start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector + ) + + # self.handle_display_and_url(v_spike_amplitudes, **backend_kwargs) + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index f95004efb9..7b0ba0454e 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -58,16 +58,17 @@ class BaseWidget: def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_kwargs, ): # every widgets must prepare a dict "plot_data" in the init self.data_plot = data_plot - self.backend = self.check_backend(backend) + backend = self.check_backend(backend) + self.backend = backend # check backend kwargs for k in backend_kwargs: if k not in default_backend_kwargs[backend]: raise Exception( f"{k} is not a valid plot argument or backend keyword argument. " - f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" + f"Possible backend keyword arguments for {backend} are: {list(default_backend_kwargs[backend].keys())}" ) - backend_kwargs_ = default_backend_kwargs[backend].copy() + backend_kwargs_ = default_backend_kwargs[self.backend].copy() backend_kwargs_.update(backend_kwargs) self.backend_kwargs = backend_kwargs_ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 1dff04d334..4ddec4134b 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -344,9 +344,10 @@ def test_sorting_summary(self): # mytest.test_plot_unit_depths() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_summary() - mytest.test_unit_locations() + # mytest.test_unit_locations() # mytest.test_quality_metrics() # mytest.test_template_metrics() + mytest.test_amplitudes() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index e87f553072..725a4c3023 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -171,51 +171,17 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.hide_axis: self.ax.axis("off") - def plot_sortingview(self, data_plot, **backend_kwargs): - import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url - - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - # ensure serializable for sortingview - print(dp.unit_ids, dp.channel_ids) - print(make_serializable(dp.unit_ids, dp.channel_ids)) - unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) - - locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} - - unit_items = [] - for unit_id in unit_ids: - unit_items.append( - vv.UnitLocationsItem( - unit_id=unit_id, x=float(dp.unit_locations[unit_id][0]), y=float(dp.unit_locations[unit_id][1]) - ) - ) - v_unit_locations = vv.UnitLocations(units=unit_items, channel_locations=locations, disable_auto_rotate=True) - - if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.sorting) - - self.view = vv.Box( - direction="horizontal", - items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_unit_locations)], - ) - else: - self.view = v_unit_locations - - # self.handle_display_and_url(view, **backend_kwargs) - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + check_ipywidget_backend() + self.next_data_plot = data_plot.copy() cm = 1 / 2.54 @@ -248,7 +214,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # w.observe(self.updater) for w in self.controller.values(): - w.observe(self.update_widget) + w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=fig.canvas, @@ -257,13 +223,12 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - self.update_widget(None) + self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) - def update_widget(self, change): + def _update_ipywidget(self, change): self.ax.clear() unit_ids = self.controller["unit_ids"].value @@ -284,5 +249,38 @@ def update_widget(self, change): fig.canvas.draw() fig.canvas.flush_events() + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + + # ensure serializable for sortingview + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + + unit_items = [] + for unit_id in unit_ids: + unit_items.append( + vv.UnitLocationsItem( + unit_id=unit_id, x=float(dp.unit_locations[unit_id][0]), y=float(dp.unit_locations[unit_id][1]) + ) + ) + + v_unit_locations = vv.UnitLocations(units=unit_items, channel_locations=locations, disable_auto_rotate=True) + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_unit_locations)], + ) + else: + self.view = v_unit_locations + + # self.handle_display_and_url(view, **backend_kwargs) + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 53f2e7eb62..52ee03ebca 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -34,8 +34,8 @@ # correlogram comparison # amplitudes -# from .amplitudes import AmplitudesWidget -# from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget +from .amplitudes import AmplitudesWidget +from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget # metrics # from .quality_metrics import QualityMetricsWidget @@ -57,8 +57,8 @@ widget_list = [ - # AmplitudesWidget, - # AllAmplitudesDistributionsWidget, + AmplitudesWidget, + AllAmplitudesDistributionsWidget, # AutoCorrelogramsWidget, # CrossCorrelogramsWidget, # QualityMetricsWidget, @@ -129,4 +129,7 @@ # plot_sorting_summary = define_widget_function_from_class(SortingSummaryWidget, "plot_sorting_summary") +plot_amplitudes = AmplitudesWidget +plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_unit_locations = UnitLocationsWidget + From 4170552e8947920743c5445cb7f093968148f547 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 17:21:25 +0200 Subject: [PATCH 07/31] refactor widgets : AutoCorrelogramsWidget + CrossCorrelogramsWidget --- .../widgets/autocorrelograms.py | 59 +++++++++++++++- .../widgets/crosscorrelograms.py | 70 ++++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 8 ++- 3 files changed, 131 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index 701817e168..f07246efa6 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -1,11 +1,68 @@ +from .base import BaseWidget, to_attr + from .crosscorrelograms import CrossCorrelogramsWidget class AutoCorrelogramsWidget(CrossCorrelogramsWidget): - possible_backends = {} + # possible_backends = {} def __init__(self, *args, **kargs): CrossCorrelogramsWidget.__init__(self, *args, **kargs) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + backend_kwargs["num_axes"] = len(dp.unit_ids) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + + bins = dp.bins + unit_ids = dp.unit_ids + correlograms = dp.correlograms + bin_width = bins[1] - bins[0] + + for i, unit_id in enumerate(unit_ids): + ccg = correlograms[i, i] + ax = self.axes.flatten()[i] + if dp.unit_colors is None: + color = "g" + else: + color = dp.unit_colors[unit_id] + ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") + ax.set_title(str(unit_id)) + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + # unit_ids = self.make_serializable(dp.unit_ids) + unit_ids = make_serializable(dp.unit_ids) + + ac_items = [] + for i in range(len(unit_ids)): + for j in range(i, len(unit_ids)): + if i == j: + ac_items.append( + vv.AutocorrelogramItem( + unit_id=unit_ids[i], + bin_edges_sec=(dp.bins / 1000.0).astype("float32"), + bin_counts=dp.correlograms[i, j].astype("int32"), + ) + ) + + self.view = vv.Autocorrelograms(autocorrelograms=ac_items) + + # self.handle_display_and_url(v_autocorrelograms, **backend_kwargs) + # return v_autocorrelograms + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + + + AutoCorrelogramsWidget.__doc__ = CrossCorrelogramsWidget.__doc__ diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 8481c8ef0d..eed76c3e04 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -1,7 +1,7 @@ import numpy as np from typing import Union -from .base import BaseWidget +from .base import BaseWidget, to_attr from ..core.waveform_extractor import WaveformExtractor from ..core.basesorting import BaseSorting from ..postprocessing import compute_correlograms @@ -27,7 +27,7 @@ class CrossCorrelogramsWidget(BaseWidget): If given, a dictionary with unit ids as keys and colors as values, default None """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -65,3 +65,69 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + backend_kwargs["ncols"] = len(dp.unit_ids) + backend_kwargs["num_axes"] = int(len(dp.unit_ids) ** 2) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + assert self.axes.ndim == 2 + + bins = dp.bins + unit_ids = dp.unit_ids + correlograms = dp.correlograms + bin_width = bins[1] - bins[0] + + for i, unit_id1 in enumerate(unit_ids): + for j, unit_id2 in enumerate(unit_ids): + ccg = correlograms[i, j] + ax = self.axes[i, j] + if i == j: + if dp.unit_colors is None: + color = "g" + else: + color = dp.unit_colors[unit_id1] + else: + color = "k" + ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") + + for i, unit_id in enumerate(unit_ids): + self.axes[0, i].set_title(str(unit_id)) + self.axes[-1, i].set_xlabel("CCG (ms)") + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + + # unit_ids = self.make_serializable(dp.unit_ids) + unit_ids = make_serializable(dp.unit_ids) + + cc_items = [] + for i in range(len(unit_ids)): + for j in range(i, len(unit_ids)): + cc_items.append( + vv.CrossCorrelogramItem( + unit_id1=unit_ids[i], + unit_id2=unit_ids[j], + bin_edges_sec=(dp.bins / 1000.0).astype("float32"), + bin_counts=dp.correlograms[i, j].astype("int32"), + ) + ) + + self.view = vv.CrossCorrelograms( + cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector + ) + + # self.handle_display_and_url(v_cross_correlograms, **backend_kwargs) + # return v_cross_correlograms + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 52ee03ebca..fb3a611c60 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,8 +10,8 @@ # from .unit_waveforms_density_map import UnitWaveformDensityMapWidget # isi/ccg/acg -# from .autocorrelograms import AutoCorrelogramsWidget -# from .crosscorrelograms import CrossCorrelogramsWidget +from .autocorrelograms import AutoCorrelogramsWidget +from .crosscorrelograms import CrossCorrelogramsWidget # peak activity @@ -59,7 +59,7 @@ widget_list = [ AmplitudesWidget, AllAmplitudesDistributionsWidget, - # AutoCorrelogramsWidget, + AutoCorrelogramsWidget, # CrossCorrelogramsWidget, # QualityMetricsWidget, # SpikeLocationsWidget, @@ -132,4 +132,6 @@ plot_amplitudes = AmplitudesWidget plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_unit_locations = UnitLocationsWidget +plot_autocorrelograms = AutoCorrelogramsWidget +plot_crosscorrelograms = CrossCorrelogramsWidget From 77e2c1fe5632f17df4504b94317248e6df284b80 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 17:33:45 +0200 Subject: [PATCH 08/31] refactor widget : SpikeLocationsWidget --- src/spikeinterface/widgets/spike_locations.py | 231 +++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 5 +- 2 files changed, 232 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index da5ad5b08c..d32c3c2f4c 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -1,7 +1,7 @@ import numpy as np from typing import Union -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors from ..core.waveform_extractor import WaveformExtractor @@ -36,7 +36,7 @@ class SpikeLocationsWidget(BaseWidget): If True, the axis is set to off. Default False (matplotlib backend) """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -105,6 +105,233 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + from matplotlib.lines import Line2D + + from probeinterface import ProbeGroup + from probeinterface.plotting import plot_probe + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + spike_locations = dp.spike_locations + + probegroup = ProbeGroup.from_dict(dp.probegroup_dict) + probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) + contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) + + for probe in probegroup.probes: + text_on_contact = None + if dp.with_channel_ids: + text_on_contact = dp.channel_ids + + poly_contact, poly_contour = plot_probe( + probe, + ax=self.ax, + contacts_colors="w", + contacts_kargs=contacts_kargs, + probe_shape_kwargs=probe_shape_kwargs, + text_on_contact=text_on_contact, + ) + poly_contact.set_zorder(2) + if poly_contour is not None: + poly_contour.set_zorder(1) + + self.ax.set_title("") + + if dp.plot_all_units: + unit_colors = {} + unit_ids = dp.all_unit_ids + for unit in dp.all_unit_ids: + if unit not in dp.unit_ids: + unit_colors[unit] = "gray" + else: + unit_colors[unit] = dp.unit_colors[unit] + else: + unit_ids = dp.unit_ids + unit_colors = dp.unit_colors + labels = dp.unit_ids + + for i, unit in enumerate(unit_ids): + locs = spike_locations[unit] + + zorder = 5 if unit in dp.unit_ids else 3 + self.ax.scatter(locs["x"], locs["y"], s=2, alpha=0.3, color=unit_colors[unit], zorder=zorder) + + handles = [ + Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) + for unit in dp.unit_ids + ] + if dp.plot_legend: + # if self.legend is not None: + if hasattr(self, 'legend') and self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + # set proper axis limits + xlims, ylims = estimate_axis_lims(spike_locations) + + ax_xlims = list(self.ax.get_xlim()) + ax_ylims = list(self.ax.get_ylim()) + + ax_xlims[0] = xlims[0] if xlims[0] < ax_xlims[0] else ax_xlims[0] + ax_xlims[1] = xlims[1] if xlims[1] > ax_xlims[1] else ax_xlims[1] + ax_ylims[0] = ylims[0] if ylims[0] < ax_ylims[0] else ax_ylims[0] + ax_ylims[1] = ylims[1] if ylims[1] > ax_ylims[1] else ax_ylims[1] + + self.ax.set_xlim(ax_xlims) + self.ax.set_ylim(ax_ylims) + if dp.hide_axis: + self.ax.axis("off") + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], + list(data_plot["unit_colors"].keys()), + ratios[0] * width_cm, + height_cm, + ) + + self.controller = unit_controller + + # mpl_plotter = MplSpikeLocationsPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) + # for w in self.controller.values(): + # w.observe(self.updater) + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + center=fig.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + ) + + # a first update + # self.updater(None) + self._update_ipywidget(None) + + + if backend_kwargs["display"]: + # self.check_backend() + display(self.widget) + + def _update_ipywidget(self, change): + + self.ax.clear() + + unit_ids = self.controller["unit_ids"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["plot_all_units"] = True + data_plot["plot_legend"] = True + data_plot["hide_axis"] = True + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + fig = self.ax.get_figure() + fig.canvas.draw() + fig.canvas.flush_events() + + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + spike_locations = dp.spike_locations + + # ensure serializable for sortingview + # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + xlims, ylims = estimate_axis_lims(spike_locations) + + unit_items = [] + for unit in unit_ids: + spike_times_sec = dp.sorting.get_unit_spike_train( + unit_id=unit, segment_index=dp.segment_index, return_times=True + ) + unit_items.append( + vv.SpikeLocationsItem( + unit_id=unit, + spike_times_sec=spike_times_sec.astype("float32"), + x_locations=spike_locations[unit]["x"].astype("float32"), + y_locations=spike_locations[unit]["y"].astype("float32"), + ) + ) + + v_spike_locations = vv.SpikeLocations( + units=unit_items, + hide_unit_selector=dp.hide_unit_selector, + x_range=xlims.astype("float32"), + y_range=ylims.astype("float32"), + channel_locations=locations, + disable_auto_rotate=True, + ) + + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[ + vv.LayoutItem(v_units_table, max_size=150), + vv.LayoutItem(v_spike_locations), + ], + ) + else: + self.view = v_spike_locations + + # self.set_view(view) + + # self.handle_display_and_url(view, **backend_kwargs) + # return view + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + + + def estimate_axis_lims(spike_locations, quantile=0.02): # set proper axis limits diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index fb3a611c60..2a146b52b9 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -24,7 +24,7 @@ # units on probe from .unit_locations import UnitLocationsWidget -# from .spike_locations import SpikeLocationsWidget +from .spike_locations import SpikeLocationsWidget # unit presence @@ -62,7 +62,7 @@ AutoCorrelogramsWidget, # CrossCorrelogramsWidget, # QualityMetricsWidget, - # SpikeLocationsWidget, + SpikeLocationsWidget, # SpikesOnTracesWidget, # TemplateMetricsWidget, # MotionWidget, @@ -134,4 +134,5 @@ plot_unit_locations = UnitLocationsWidget plot_autocorrelograms = AutoCorrelogramsWidget plot_crosscorrelograms = CrossCorrelogramsWidget +plot_spike_locations = SpikeLocationsWidget From 1bdb64f5e0d0a8dda32460efc92a6cd92b6c3e21 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 20:52:03 +0200 Subject: [PATCH 09/31] widget refactor TemplateMetricsWidget QualityMetricsWidget --- src/spikeinterface/widgets/metrics.py | 211 +++++++++++++++++- src/spikeinterface/widgets/quality_metrics.py | 2 +- .../widgets/template_metrics.py | 2 +- src/spikeinterface/widgets/widget_list.py | 10 +- 4 files changed, 217 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 8e77e4a0f0..207e3a8a6c 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -1,8 +1,9 @@ import warnings import numpy as np -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors +from ..core.core_tools import check_json class MetricsBaseWidget(BaseWidget): @@ -29,7 +30,7 @@ class MetricsBaseWidget(BaseWidget): If True, metrics data are included in unit table, by default True """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -77,3 +78,209 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + metrics = dp.metrics + num_metrics = len(metrics.columns) + + if "figsize" not in backend_kwargs: + backend_kwargs["figsize"] = (2 * num_metrics, 2 * num_metrics) + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + backend_kwargs["num_axes"] = num_metrics ** 2 + backend_kwargs["ncols"] = num_metrics + + all_unit_ids = metrics.index.values + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + assert self.axes.ndim == 2 + + if dp.unit_ids is None: + colors = ["gray"] * len(all_unit_ids) + else: + colors = [] + for unit in all_unit_ids: + color = "gray" if unit not in dp.unit_ids else dp.unit_colors[unit] + colors.append(color) + + self.patches = [] + for i, m1 in enumerate(metrics.columns): + for j, m2 in enumerate(metrics.columns): + if i == j: + self.axes[i, j].hist(metrics[m1], color="gray") + else: + p = self.axes[i, j].scatter(metrics[m1], metrics[m2], c=colors, s=3, marker="o") + self.patches.append(p) + if i == num_metrics - 1: + self.axes[i, j].set_xlabel(m2, fontsize=10) + if j == 0: + self.axes[i, j].set_ylabel(m1, fontsize=10) + self.axes[i, j].set_xticklabels([]) + self.axes[i, j].set_yticklabels([]) + self.axes[i, j].spines["top"].set_visible(False) + self.axes[i, j].spines["right"].set_visible(False) + + self.figure.subplots_adjust(top=0.8, wspace=0.2, hspace=0.2) + + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + if data_plot["unit_ids"] is None: + data_plot["unit_ids"] = [] + + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm + ) + + self.controller = unit_controller + + # mpl_plotter = MplMetricsPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) + # for w in self.controller.values(): + # w.observe(self.updater) + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + + self.widget = widgets.AppLayout( + center=self.figure.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + ) + + # a first update + # self.updater(None) + self._update_ipywidget(None) + + if backend_kwargs["display"]: + # self.check_backend() + display(self.widget) + + def _update_ipywidget(self, change): + from matplotlib.lines import Line2D + + unit_ids = self.controller["unit_ids"].value + + unit_colors = self.data_plot["unit_colors"] + # matplotlib next_data_plot dict update at each call + all_units = list(unit_colors.keys()) + colors = [] + sizes = [] + for unit in all_units: + color = "gray" if unit not in unit_ids else unit_colors[unit] + size = 1 if unit not in unit_ids else 5 + colors.append(color) + sizes.append(size) + + # here we do a trick: we just update colors + # if hasattr(self.mpl_plotter, "patches"): + if hasattr(self, "patches"): + # for p in self.mpl_plotter.patches: + for p in self.patches: + p.set_color(colors) + p.set_sizes(sizes) + else: + backend_kwargs = {} + backend_kwargs["figure"] = self.figure + # self.mpl_plotter.do_plot(self.data_plot, **backend_kwargs) + self.plot_matplotlib(self.data_plot, **backend_kwargs) + + if len(unit_ids) > 0: + for l in self.figure.legends: + l.remove() + handles = [ + Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) + for unit in unit_ids + ] + labels = unit_ids + self.figure.legend( + handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + metrics = dp.metrics + metric_names = list(metrics.columns) + + if dp.unit_ids is None: + unit_ids = metrics.index.values + else: + unit_ids = dp.unit_ids + # unit_ids = self.make_serializable(unit_ids) + unit_ids = make_serializable(unit_ids) + + metrics_sv = [] + for col in metric_names: + dtype = metrics.iloc[0][col].dtype + metric = vv.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str) + metrics_sv.append(metric) + + units_m = [] + for unit_id in unit_ids: + values = check_json(metrics.loc[unit_id].to_dict()) + values_skip_nans = {} + for k, v in values.items(): + if np.isnan(v): + continue + values_skip_nans[k] = v + + units_m.append(vv.UnitMetricsGraphUnit(unit_id=unit_id, values=values_skip_nans)) + v_metrics = vv.UnitMetricsGraph(units=units_m, metrics=metrics_sv) + + if not dp.hide_unit_selector: + if dp.include_metrics_data: + # make a view of the sorting to add tmp properties + sorting_copy = dp.sorting.select_units(unit_ids=dp.sorting.unit_ids) + for col in metric_names: + if col not in sorting_copy.get_property_keys(): + sorting_copy.set_property(col, metrics[col].values) + # generate table with properties + v_units_table = generate_unit_table_view(sorting_copy, unit_properties=metric_names) + else: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Splitter( + direction="horizontal", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_metrics) + ) + else: + self.view = v_metrics + + # self.handle_display_and_url(view, **backend_kwargs) + # return view + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) \ No newline at end of file diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index f1c2ad6e23..46bcd6c07b 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -23,7 +23,7 @@ class QualityMetricsWidget(MetricsBaseWidget): For sortingview backend, if True the unit selector is not displayed, default False """ - possible_backends = {} + # possible_backends = {} def __init__( self, diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index b441882730..7361757666 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -22,7 +22,7 @@ class TemplateMetricsWidget(MetricsBaseWidget): For sortingview backend, if True the unit selector is not displayed, default False """ - possible_backends = {} + # possible_backends = {} def __init__( self, diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 2a146b52b9..e9e2b179b0 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -38,8 +38,8 @@ from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget # metrics -# from .quality_metrics import QualityMetricsWidget -# from .template_metrics import TemplateMetricsWidget +from .quality_metrics import QualityMetricsWidget +from .template_metrics import TemplateMetricsWidget # motion/drift @@ -61,10 +61,10 @@ AllAmplitudesDistributionsWidget, AutoCorrelogramsWidget, # CrossCorrelogramsWidget, - # QualityMetricsWidget, + QualityMetricsWidget, SpikeLocationsWidget, # SpikesOnTracesWidget, - # TemplateMetricsWidget, + TemplateMetricsWidget, # MotionWidget, # TemplateSimilarityWidget, # TimeseriesWidget, @@ -135,4 +135,6 @@ plot_autocorrelograms = AutoCorrelogramsWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_spike_locations = SpikeLocationsWidget +plot_template_metrics = TemplateMetricsWidget +plot_quality_metrics = QualityMetricsWidget From 5394263962e8f2e6370881af62e22817677be0ce Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 21:01:42 +0200 Subject: [PATCH 10/31] widget refactor MotionWidget --- src/spikeinterface/widgets/motion.py | 128 +++++++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 7 +- 2 files changed, 130 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 82e9be2407..48aba8de47 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -1,7 +1,7 @@ import numpy as np from warnings import warn -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -36,7 +36,7 @@ class MotionWidget(BaseWidget): The alpha of the scatter points, default 0.5 """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -68,3 +68,127 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + from matplotlib.colors import Normalize + + from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks + + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + + assert backend_kwargs["axes"] is None + assert backend_kwargs["ax"] is None + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + fig = self.figure + fig.clear() + + is_rigid = dp.motion.shape[1] == 1 + + gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.3) + ax0 = fig.add_subplot(gs[0, 0]) + ax1 = fig.add_subplot(gs[0, 1]) + ax2 = fig.add_subplot(gs[1, 0]) + if not is_rigid: + ax3 = fig.add_subplot(gs[1, 1]) + ax1.sharex(ax0) + ax1.sharey(ax0) + + if dp.motion_lim is None: + motion_lim = np.max(np.abs(dp.motion)) * 1.05 + else: + motion_lim = dp.motion_lim + + if dp.times is None: + temporal_bins_plot = dp.temporal_bins + x = dp.peaks["sample_index"] / dp.sampling_frequency + else: + # use real times and adjust temporal bins with t_start + temporal_bins_plot = dp.temporal_bins + dp.times[0] + x = dp.times[dp.peaks["sample_index"]] + + corrected_location = correct_motion_on_peaks( + dp.peaks, + dp.peak_locations, + dp.sampling_frequency, + dp.motion, + dp.temporal_bins, + dp.spatial_bins, + direction="y", + ) + + y = dp.peak_locations["y"] + y2 = corrected_location["y"] + if dp.scatter_decimate is not None: + x = x[:: dp.scatter_decimate] + y = y[:: dp.scatter_decimate] + y2 = y2[:: dp.scatter_decimate] + + if dp.color_amplitude: + amps = dp.peaks["amplitude"] + amps_abs = np.abs(amps) + q_95 = np.quantile(amps_abs, 0.95) + if dp.scatter_decimate is not None: + amps = amps[:: dp.scatter_decimate] + amps_abs = amps_abs[:: dp.scatter_decimate] + cmap = plt.get_cmap(dp.amplitude_cmap) + if dp.amplitude_clim is None: + amps = amps_abs + amps /= q_95 + c = cmap(amps) + else: + norm_function = Normalize(vmin=dp.amplitude_clim[0], vmax=dp.amplitude_clim[1], clip=True) + c = cmap(norm_function(amps)) + color_kwargs = dict( + color=None, + c=c, + alpha=dp.amplitude_alpha, + ) + else: + color_kwargs = dict(color="k", c=None, alpha=dp.amplitude_alpha) + + ax0.scatter(x, y, s=1, **color_kwargs) + if dp.depth_lim is not None: + ax0.set_ylim(*dp.depth_lim) + ax0.set_title("Peak depth") + ax0.set_xlabel("Times [s]") + ax0.set_ylabel("Depth [um]") + + ax1.scatter(x, y2, s=1, **color_kwargs) + ax1.set_xlabel("Times [s]") + ax1.set_ylabel("Depth [um]") + ax1.set_title("Corrected peak depth") + + ax2.plot(temporal_bins_plot, dp.motion, alpha=0.2, color="black") + ax2.plot(temporal_bins_plot, np.mean(dp.motion, axis=1), color="C0") + ax2.set_ylim(-motion_lim, motion_lim) + ax2.set_ylabel("Motion [um]") + ax2.set_title("Motion vectors") + axes = [ax0, ax1, ax2] + + if not is_rigid: + im = ax3.imshow( + dp.motion.T, + aspect="auto", + origin="lower", + extent=( + temporal_bins_plot[0], + temporal_bins_plot[-1], + dp.spatial_bins[0], + dp.spatial_bins[-1], + ), + ) + im.set_clim(-motion_lim, motion_lim) + cbar = fig.colorbar(im) + cbar.ax.set_xlabel("motion [um]") + ax3.set_xlabel("Times [s]") + ax3.set_ylabel("Depth [um]") + ax3.set_title("Motion vectors") + axes.append(ax3) + self.axes = np.array(axes) \ No newline at end of file diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index e9e2b179b0..897965b4eb 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -43,7 +43,7 @@ # motion/drift -# from .motion import MotionWidget +from .motion import MotionWidget # similarity # from .template_similarity import TemplateSimilarityWidget @@ -60,12 +60,12 @@ AmplitudesWidget, AllAmplitudesDistributionsWidget, AutoCorrelogramsWidget, - # CrossCorrelogramsWidget, + CrossCorrelogramsWidget, QualityMetricsWidget, SpikeLocationsWidget, # SpikesOnTracesWidget, TemplateMetricsWidget, - # MotionWidget, + MotionWidget, # TemplateSimilarityWidget, # TimeseriesWidget, UnitLocationsWidget, @@ -137,4 +137,5 @@ plot_spike_locations = SpikeLocationsWidget plot_template_metrics = TemplateMetricsWidget plot_quality_metrics = QualityMetricsWidget +plot_motion = MotionWidget From ddf0d8d3e417acd652c1784b9cc0092f49fb4670 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 21:06:36 +0200 Subject: [PATCH 11/31] refactor widget : TemplateSimilarityWidget --- .../widgets/template_similarity.py | 56 ++++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 5 +- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index 475c873c29..93b9a49f49 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -1,9 +1,8 @@ import numpy as np from typing import Union -from .base import BaseWidget +from .base import BaseWidget, to_attr from ..core.waveform_extractor import WaveformExtractor -from ..core.basesorting import BaseSorting class TemplateSimilarityWidget(BaseWidget): @@ -27,7 +26,7 @@ class TemplateSimilarityWidget(BaseWidget): If True, color bar is displayed, default True. """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -63,3 +62,54 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + im = self.ax.matshow(dp.similarity, cmap=dp.cmap) + + if dp.show_unit_ticks: + # Major ticks + self.ax.set_xticks(np.arange(0, len(dp.unit_ids))) + self.ax.set_yticks(np.arange(0, len(dp.unit_ids))) + self.ax.xaxis.tick_bottom() + + # Labels for major ticks + self.ax.set_yticklabels(dp.unit_ids, fontsize=12) + self.ax.set_xticklabels(dp.unit_ids, fontsize=12) + if dp.show_colorbar: + self.figure.colorbar(im) + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + + # ensure serializable for sortingview + # unit_ids = self.make_serializable(dp.unit_ids) + unit_ids = make_serializable(dp.unit_ids) + + # similarity + ss_items = [] + for i1, u1 in enumerate(unit_ids): + for i2, u2 in enumerate(unit_ids): + ss_items.append( + vv.UnitSimilarityScore(unit_id1=u1, unit_id2=u2, similarity=dp.similarity[i1, i2].astype("float32")) + ) + + self.view = vv.UnitSimilarityMatrix(unit_ids=list(unit_ids), similarity_scores=ss_items) + + # self.handle_display_and_url(view, **backend_kwargs) + # return view + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 897965b4eb..e2366920e5 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -46,7 +46,7 @@ from .motion import MotionWidget # similarity -# from .template_similarity import TemplateSimilarityWidget +from .template_similarity import TemplateSimilarityWidget # from .unit_depths import UnitDepthsWidget @@ -66,7 +66,7 @@ # SpikesOnTracesWidget, TemplateMetricsWidget, MotionWidget, - # TemplateSimilarityWidget, + TemplateSimilarityWidget, # TimeseriesWidget, UnitLocationsWidget, # UnitTemplatesWidget, @@ -138,4 +138,5 @@ plot_template_metrics = TemplateMetricsWidget plot_quality_metrics = QualityMetricsWidget plot_motion = MotionWidget +plot_template_similarity = TemplateSimilarityWidget From 9890419db8fb8487d04fae471438002f67070a40 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 21:30:21 +0200 Subject: [PATCH 12/31] refactor widgets UnitTemplatesWidget UnitWaveformsWidget --- src/spikeinterface/widgets/unit_templates.py | 53 +++- src/spikeinterface/widgets/unit_waveforms.py | 250 ++++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 10 +- 3 files changed, 305 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 41c4ece09c..84856d2df4 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -1,12 +1,61 @@ from .unit_waveforms import UnitWaveformsWidget - +from .base import to_attr class UnitTemplatesWidget(UnitWaveformsWidget): - possible_backends = {} + # possible_backends = {} def __init__(self, *args, **kargs): kargs["plot_waveforms"] = False UnitWaveformsWidget.__init__(self, *args, **kargs) + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # ensure serializable for sortingview + unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids + unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices + + # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + templates_dict = {} + for u_i, unit in enumerate(unit_ids): + templates_dict[unit] = {} + templates_dict[unit]["mean"] = dp.templates[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] + templates_dict[unit]["std"] = dp.template_stds[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] + + aw_items = [ + vv.AverageWaveformItem( + unit_id=u, + channel_ids=list(unit_id_to_channel_ids[u]), + waveform=t["mean"].astype("float32"), + waveform_std_dev=t["std"].astype("float32"), + ) + for u, t in templates_dict.items() + ] + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations) + + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.waveform_extractor.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_average_waveforms)], + ) + else: + self.view = v_average_waveforms + + # self.handle_display_and_url(view, **backend_kwargs) + # return view + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + + + UnitTemplatesWidget.__doc__ = UnitWaveformsWidget.__doc__ diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index ba707a8221..49c75bf046 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -1,6 +1,6 @@ import numpy as np -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors from ..core import ChannelSparsity @@ -59,7 +59,7 @@ class UnitWaveformsWidget(BaseWidget): Display legend, default True """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -165,6 +165,252 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + from probeinterface.plotting import plot_probe + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + dp = to_attr(data_plot) + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + if backend_kwargs.get("axes", None) is not None: + assert len(backend_kwargs["axes"]) >= len(dp.unit_ids), "Provide as many 'axes' as neurons" + elif backend_kwargs.get("ax", None) is not None: + assert dp.same_axis, "If 'same_axis' is not used, provide as many 'axes' as neurons" + else: + if dp.same_axis: + backend_kwargs["num_axes"] = 1 + backend_kwargs["ncols"] = None + else: + backend_kwargs["num_axes"] = len(dp.unit_ids) + backend_kwargs["ncols"] = min(dp.ncols, len(dp.unit_ids)) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + for i, unit_id in enumerate(dp.unit_ids): + if dp.same_axis: + ax = self.ax + else: + ax = self.axes.flatten()[i] + color = dp.unit_colors[unit_id] + + chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id] + xvectors_flat = dp.xvectors[:, chan_inds].T.flatten() + + # plot waveforms + if dp.plot_waveforms: + wfs = dp.wfs_by_ids[unit_id] + if dp.unit_selected_waveforms is not None: + wfs = wfs[dp.unit_selected_waveforms[unit_id]] + elif dp.max_spikes_per_unit is not None: + if len(wfs) > dp.max_spikes_per_unit: + random_idxs = np.random.permutation(len(wfs))[: dp.max_spikes_per_unit] + wfs = wfs[random_idxs] + wfs = wfs * dp.y_scale + dp.y_offset[None, :, chan_inds] + wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T + + if dp.x_offset_units: + # 0.7 is to match spacing in xvect + xvec = xvectors_flat + i * 0.7 * dp.delta_x + else: + xvec = xvectors_flat + + ax.plot(xvec, wfs_flat, lw=dp.lw_waveforms, alpha=dp.alpha_waveforms, color=color) + + if not dp.plot_templates: + ax.get_lines()[-1].set_label(f"{unit_id}") + + # plot template + if dp.plot_templates: + template = dp.templates[i, :, :][:, chan_inds] * dp.y_scale + dp.y_offset[:, chan_inds] + + if dp.x_offset_units: + # 0.7 is to match spacing in xvect + xvec = xvectors_flat + i * 0.7 * dp.delta_x + else: + xvec = xvectors_flat + + ax.plot( + xvec, template.T.flatten(), lw=dp.lw_templates, alpha=dp.alpha_templates, color=color, label=unit_id + ) + + template_label = dp.unit_ids[i] + if dp.set_title: + ax.set_title(f"template {template_label}") + + # plot channels + if dp.plot_channels: + # TODO enhance this + ax.scatter(dp.channel_locations[:, 0], dp.channel_locations[:, 1], color="k") + + if dp.same_axis and dp.plot_legend: + # if self.legend is not None: + if hasattr(self, 'legend') and self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + self.we = we = data_plot["waveform_extractor"] + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.1, 0.7, 0.2] + + with plt.ioff(): + output1 = widgets.Output() + with output1: + self.fig_wf = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + output2 = widgets.Output() + with output2: + self.fig_probe, self.ax_probe = plt.subplots(figsize=((ratios[2] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm + ) + + same_axis_button = widgets.Checkbox( + value=False, + description="same axis", + disabled=False, + ) + + plot_templates_button = widgets.Checkbox( + value=True, + description="plot templates", + disabled=False, + ) + + hide_axis_button = widgets.Checkbox( + value=True, + description="hide axis", + disabled=False, + ) + + footer = widgets.HBox([same_axis_button, plot_templates_button, hide_axis_button]) + + self.controller = { + "same_axis": same_axis_button, + "plot_templates": plot_templates_button, + "hide_axis": hide_axis_button, + } + self.controller.update(unit_controller) + + # mpl_plotter = MplUnitWaveformPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, fig_wf, ax_probe, self.controller) + # for w in self.controller.values(): + # w.observe(self.updater) + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + + self.widget = widgets.AppLayout( + center=self.fig_wf.canvas, + left_sidebar=unit_widget, + right_sidebar=self.fig_probe.canvas, + pane_widths=ratios, + footer=footer, + ) + + # a first update + # self.updater(None) + self._update_ipywidget(None) + + if backend_kwargs["display"]: + # self.check_backend() + display(self.widget) + + def _update_ipywidget(self, change): + self.fig_wf.clear() + self.ax_probe.clear() + + unit_ids = self.controller["unit_ids"].value + same_axis = self.controller["same_axis"].value + plot_templates = self.controller["plot_templates"].value + hide_axis = self.controller["hide_axis"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["templates"] = self.we.get_all_templates(unit_ids=unit_ids) + data_plot["template_stds"] = self.we.get_all_templates(unit_ids=unit_ids, mode="std") + data_plot["same_axis"] = same_axis + data_plot["plot_templates"] = plot_templates + if data_plot["plot_waveforms"]: + data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} + + backend_kwargs = {} + + if same_axis: + backend_kwargs["ax"] = self.fig_wf.add_subplot() + data_plot["set_title"] = False + else: + backend_kwargs["figure"] = self.fig_wf + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + if same_axis: + # self.mpl_plotter.ax.axis("equal") + self.ax.axis("equal") + if hide_axis: + # self.mpl_plotter.ax.axis("off") + self.ax.axis("off") + else: + if hide_axis: + for i in range(len(unit_ids)): + # ax = self.mpl_plotter.axes.flatten()[i] + ax = self.axes.flatten()[i] + ax.axis("off") + + # update probe plot + channel_locations = self.we.get_channel_locations() + self.ax_probe.plot( + channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 + ) + self.ax_probe.axis("off") + self.ax_probe.axis("equal") + + for unit in unit_ids: + channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] + self.ax_probe.plot( + channel_locations[channel_inds, 0], + channel_locations[channel_inds, 1], + ls="", + marker="o", + markersize=3, + color=self.next_data_plot["unit_colors"][unit], + ) + self.ax_probe.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10) + fig_probe = self.ax_probe.get_figure() + + self.fig_wf.canvas.draw() + self.fig_wf.canvas.flush_events() + fig_probe.canvas.draw() + fig_probe.canvas.flush_events() + def get_waveforms_scales(we, templates, channel_locations, x_offset_units=False): """ diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index e2366920e5..cb19eda93c 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -5,8 +5,8 @@ # from .timeseries import TimeseriesWidget # waveform -# from .unit_waveforms import UnitWaveformsWidget -# from .unit_templates import UnitTemplatesWidget +from .unit_waveforms import UnitWaveformsWidget +from .unit_templates import UnitTemplatesWidget # from .unit_waveforms_density_map import UnitWaveformDensityMapWidget # isi/ccg/acg @@ -69,8 +69,8 @@ TemplateSimilarityWidget, # TimeseriesWidget, UnitLocationsWidget, - # UnitTemplatesWidget, - # UnitWaveformsWidget, + UnitTemplatesWidget, + UnitWaveformsWidget, # UnitWaveformDensityMapWidget, # UnitDepthsWidget, # summary @@ -139,4 +139,6 @@ plot_quality_metrics = QualityMetricsWidget plot_motion = MotionWidget plot_template_similarity = TemplateSimilarityWidget +plot_unit_templates = UnitTemplatesWidget +plot_unit_waveforms = UnitWaveformsWidget From f064513b1631697b7db83197d8113852edd592e8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 21:36:39 +0200 Subject: [PATCH 13/31] widget refactor : UnitWaveformDensityMapWidget --- .../widgets/unit_waveforms_density_map.py | 76 ++++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 5 +- 2 files changed, 77 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 9f3e5e86b5..9216373d87 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -1,6 +1,6 @@ import numpy as np -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors from ..core import ChannelSparsity, get_template_extremum_channel @@ -33,7 +33,7 @@ class UnitWaveformDensityMapWidget(BaseWidget): all channel per units, default False """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -156,3 +156,75 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None: + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + else: + if dp.same_axis: + num_axes = 1 + else: + num_axes = len(dp.unit_ids) + backend_kwargs["ncols"] = 1 + backend_kwargs["num_axes"] = num_axes + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + if dp.same_axis: + ax = self.ax + hist2d = dp.all_hist2d + im = ax.imshow( + hist2d.T, + interpolation="nearest", + origin="lower", + aspect="auto", + extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), + cmap="hot", + ) + else: + for unit_index, unit_id in enumerate(dp.unit_ids): + hist2d = dp.all_hist2d[unit_id] + ax = self.axes.flatten()[unit_index] + im = ax.imshow( + hist2d.T, + interpolation="nearest", + origin="lower", + aspect="auto", + extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), + cmap="hot", + ) + + for unit_index, unit_id in enumerate(dp.unit_ids): + if dp.same_axis: + ax = self.ax + else: + ax = self.axes.flatten()[unit_index] + color = dp.unit_colors[unit_id] + ax.plot(dp.templates_flat[unit_id], color=color, lw=1) + + # final cosmetics + for unit_index, unit_id in enumerate(dp.unit_ids): + if dp.same_axis: + ax = self.ax + if unit_index != 0: + continue + else: + ax = self.axes.flatten()[unit_index] + chan_inds = dp.channel_inds[unit_id] + for i, chan_ind in enumerate(chan_inds): + if i != 0: + ax.axvline(i * dp.template_width, color="w", lw=3) + channel_id = dp.channel_ids[chan_ind] + x = i * dp.template_width + dp.template_width // 2 + y = (dp.bin_max + dp.bin_min) / 2.0 + ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") + + ax.set_xticks([]) + ax.set_ylabel(f"unit_id {unit_id}") diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index cb19eda93c..68034ee27e 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -7,7 +7,7 @@ # waveform from .unit_waveforms import UnitWaveformsWidget from .unit_templates import UnitTemplatesWidget -# from .unit_waveforms_density_map import UnitWaveformDensityMapWidget +from .unit_waveforms_density_map import UnitWaveformDensityMapWidget # isi/ccg/acg from .autocorrelograms import AutoCorrelogramsWidget @@ -71,7 +71,7 @@ UnitLocationsWidget, UnitTemplatesWidget, UnitWaveformsWidget, - # UnitWaveformDensityMapWidget, + UnitWaveformDensityMapWidget, # UnitDepthsWidget, # summary # UnitSummaryWidget, @@ -141,4 +141,5 @@ plot_template_similarity = TemplateSimilarityWidget plot_unit_templates = UnitTemplatesWidget plot_unit_waveforms = UnitWaveformsWidget +plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget From d9307ab24a96dad6dbfd2a72b6f68615a79a1d15 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 22:32:29 +0200 Subject: [PATCH 14/31] refactor widget : UnitDepthsWidget --- src/spikeinterface/widgets/unit_depths.py | 23 +++++++++++++++++++++-- src/spikeinterface/widgets/widget_list.py | 5 +++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 5ceee0c133..9b710815e4 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -1,7 +1,7 @@ import numpy as np from warnings import warn -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -24,7 +24,7 @@ class UnitDepthsWidget(BaseWidget): Sign of peak for amplitudes, default 'neg' """ - possible_backends = {} + # possible_backends = {} def __init__( self, waveform_extractor, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs @@ -56,3 +56,22 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + ax = self.ax + size = dp.num_spikes / max(dp.num_spikes) * 120 + ax.scatter(dp.unit_amplitudes, dp.unit_depths, color=dp.colors, s=size) + + ax.set_aspect(3) + ax.set_xlabel("amplitude") + ax.set_ylabel("depth [um]") + ax.set_xlim(0, max(dp.unit_amplitudes) * 1.2) + diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 68034ee27e..4ded22305e 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -49,7 +49,7 @@ from .template_similarity import TemplateSimilarityWidget -# from .unit_depths import UnitDepthsWidget +from .unit_depths import UnitDepthsWidget # summary # from .unit_summary import UnitSummaryWidget @@ -72,7 +72,7 @@ UnitTemplatesWidget, UnitWaveformsWidget, UnitWaveformDensityMapWidget, - # UnitDepthsWidget, + UnitDepthsWidget, # summary # UnitSummaryWidget, # SortingSummaryWidget, @@ -142,4 +142,5 @@ plot_unit_templates = UnitTemplatesWidget plot_unit_waveforms = UnitWaveformsWidget plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget +plot_unit_depths = UnitDepthsWidget From 8da772269f8ff85244431f7ca16d41172eec27f7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 22:55:33 +0200 Subject: [PATCH 15/31] refactor widget : UnitSummaryWidget --- src/spikeinterface/widgets/unit_summary.py | 189 ++++++++++++++++----- src/spikeinterface/widgets/widget_list.py | 5 +- 2 files changed, 150 insertions(+), 44 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 8e1ffe2637..68fa8b77d2 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -1,7 +1,7 @@ import numpy as np from typing import Union -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -31,7 +31,7 @@ class UnitSummaryWidget(BaseWidget): If WaveformExtractor is already sparse, the argument is ignored """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -48,55 +48,160 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(we.sorting) - if we.is_extension("unit_locations"): - plot_data_unit_locations = UnitLocationsWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False - ).plot_data - unit_locations = waveform_extractor.load_extension("unit_locations").get_data(outputs="by_unit") - unit_location = unit_locations[unit_id] - else: - plot_data_unit_locations = None - unit_location = None + # if we.is_extension("unit_locations"): + # plot_data_unit_locations = UnitLocationsWidget( + # we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False + # ).plot_data + # unit_locations = waveform_extractor.load_extension("unit_locations").get_data(outputs="by_unit") + # unit_location = unit_locations[unit_id] + # else: + # plot_data_unit_locations = None + # unit_location = None + + # plot_data_waveforms = UnitWaveformsWidget( + # we, + # unit_ids=[unit_id], + # unit_colors=unit_colors, + # plot_templates=True, + # same_axis=True, + # plot_legend=False, + # sparsity=sparsity, + # ).plot_data + + # plot_data_waveform_density = UnitWaveformDensityMapWidget( + # we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, plot_templates=True, same_axis=False + # ).plot_data + + # if we.is_extension("correlograms"): + # plot_data_acc = AutoCorrelogramsWidget( + # we, + # unit_ids=[unit_id], + # unit_colors=unit_colors, + # ).plot_data + # else: + # plot_data_acc = None + + # use other widget to plot data + # if we.is_extension("spike_amplitudes"): + # plot_data_amplitudes = AmplitudesWidget( + # we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True + # ).plot_data + # else: + # plot_data_amplitudes = None - plot_data_waveforms = UnitWaveformsWidget( - we, - unit_ids=[unit_id], + plot_data = dict( + we=we, + unit_id=unit_id, unit_colors=unit_colors, - plot_templates=True, - same_axis=True, - plot_legend=False, sparsity=sparsity, - ).plot_data + # unit_location=unit_location, + # plot_data_unit_locations=plot_data_unit_locations, + # plot_data_waveforms=plot_data_waveforms, + # plot_data_waveform_density=plot_data_waveform_density, + # plot_data_acc=plot_data_acc, + # plot_data_amplitudes=plot_data_amplitudes, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - plot_data_waveform_density = UnitWaveformDensityMapWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, plot_templates=True, same_axis=False - ).plot_data + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + + unit_id = dp.unit_id + we = dp.we + unit_colors = dp.unit_colors + sparsity = dp.sparsity + + + # force the figure without axes + if "figsize" not in backend_kwargs: + backend_kwargs["figsize"] = (18, 7) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + backend_kwargs["num_axes"] = 0 + backend_kwargs["ax"] = None + backend_kwargs["axes"] = None + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + # and use custum grid spec + fig = self.figure + nrows = 2 + ncols = 3 + # if dp.plot_data_acc is not None or dp.plot_data_amplitudes is not None: + if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"): + ncols += 1 + # if dp.plot_data_amplitudes is not None : + if we.is_extension("spike_amplitudes"): + + nrows += 1 + gs = fig.add_gridspec(nrows, ncols) + # if dp.plot_data_unit_locations is not None: + if we.is_extension("unit_locations"): + ax1 = fig.add_subplot(gs[:2, 0]) + # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) + w = UnitLocationsWidget( + we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, + backend='matplotlib', ax=ax1) + + unit_locations = we.load_extension("unit_locations").get_data(outputs="by_unit") + unit_location = unit_locations[unit_id] + # x, y = dp.unit_location[0], dp.unit_location[1] + x, y = unit_location[0], unit_location[1] + ax1.set_xlim(x - 80, x + 80) + ax1.set_ylim(y - 250, y + 250) + ax1.set_xticks([]) + ax1.set_xlabel(None) + ax1.set_ylabel(None) + + ax2 = fig.add_subplot(gs[:2, 1]) + # UnitWaveformPlotter().do_plot(dp.plot_data_waveforms, ax=ax2) + w = UnitWaveformsWidget( + we, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_templates=True, + same_axis=True, + plot_legend=False, + sparsity=sparsity, + backend='matplotlib', ax=ax2) + + ax2.set_title(None) + + ax3 = fig.add_subplot(gs[:2, 2]) + # UnitWaveformDensityMapPlotter().do_plot(dp.plot_data_waveform_density, ax=ax3) + UnitWaveformDensityMapWidget( + we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, same_axis=False, + backend='matplotlib', ax=ax3) + ax3.set_ylabel(None) + + # if dp.plot_data_acc is not None: if we.is_extension("correlograms"): - plot_data_acc = AutoCorrelogramsWidget( + ax4 = fig.add_subplot(gs[:2, 3]) + # AutoCorrelogramsPlotter().do_plot(dp.plot_data_acc, ax=ax4) + AutoCorrelogramsWidget( we, unit_ids=[unit_id], unit_colors=unit_colors, - ).plot_data - else: - plot_data_acc = None + backend='matplotlib', ax=ax4, + ) - # use other widget to plot data - if we.is_extension("spike_amplitudes"): - plot_data_amplitudes = AmplitudesWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True - ).plot_data - else: - plot_data_amplitudes = None - plot_data = dict( - unit_id=unit_id, - unit_location=unit_location, - plot_data_unit_locations=plot_data_unit_locations, - plot_data_waveforms=plot_data_waveforms, - plot_data_waveform_density=plot_data_waveform_density, - plot_data_acc=plot_data_acc, - plot_data_amplitudes=plot_data_amplitudes, - ) + ax4.set_title(None) + ax4.set_yticks([]) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + # if dp.plot_data_amplitudes is not None: + if we.is_extension("spike_amplitudes"): + ax5 = fig.add_subplot(gs[2, :3]) + ax6 = fig.add_subplot(gs[2, 3]) + axes = np.array([ax5, ax6]) + # AmplitudesPlotter().do_plot(dp.plot_data_amplitudes, axes=axes) + AmplitudesWidget( + we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True, + backend='matplotlib', axes=axes) + + fig.suptitle(f"unit_id: {dp.unit_id}") diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 4ded22305e..5820477dc8 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -52,7 +52,7 @@ from .unit_depths import UnitDepthsWidget # summary -# from .unit_summary import UnitSummaryWidget +from .unit_summary import UnitSummaryWidget # from .sorting_summary import SortingSummaryWidget @@ -74,7 +74,7 @@ UnitWaveformDensityMapWidget, UnitDepthsWidget, # summary - # UnitSummaryWidget, + UnitSummaryWidget, # SortingSummaryWidget, ] @@ -143,4 +143,5 @@ plot_unit_waveforms = UnitWaveformsWidget plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget plot_unit_depths = UnitDepthsWidget +plot_unit_summary = UnitSummaryWidget From fa49471061712fee17d9475a1947d3f2d3e6d607 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 23:17:43 +0200 Subject: [PATCH 16/31] refactor widget : SortingSummaryWidget --- src/spikeinterface/widgets/sorting_summary.py | 135 +++++++++++++++--- src/spikeinterface/widgets/widget_list.py | 5 +- 2 files changed, 122 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 8f50eb1dde..bdf692888f 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -1,6 +1,6 @@ import numpy as np -from .base import BaseWidget, define_widget_function_from_class +from .base import BaseWidget, to_attr from .amplitudes import AmplitudesWidget from .crosscorrelograms import CrossCorrelogramsWidget @@ -34,7 +34,7 @@ class SortingSummaryWidget(BaseWidget): (sortingview backend) """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -56,27 +56,130 @@ def __init__( unit_ids = sorting.get_unit_ids() # use other widgets to generate data (except for similarity) - template_plot_data = UnitTemplatesWidget( - we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True - ).plot_data - ccg_plot_data = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data - amps_plot_data = AmplitudesWidget( - we, unit_ids=unit_ids, max_spikes_per_unit=max_amplitudes_per_unit, hide_unit_selector=True - ).plot_data - locs_plot_data = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data - sim_plot_data = TemplateSimilarityWidget(we, unit_ids=unit_ids).plot_data + # template_plot_data = UnitTemplatesWidget( + # we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True + # ).plot_data + # ccg_plot_data = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data + # amps_plot_data = AmplitudesWidget( + # we, unit_ids=unit_ids, max_spikes_per_unit=max_amplitudes_per_unit, hide_unit_selector=True + # ).plot_data + # locs_plot_data = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data + # sim_plot_data = TemplateSimilarityWidget(we, unit_ids=unit_ids).plot_data plot_data = dict( waveform_extractor=waveform_extractor, unit_ids=unit_ids, - templates=template_plot_data, - correlograms=ccg_plot_data, - amplitudes=amps_plot_data, - similarity=sim_plot_data, - unit_locations=locs_plot_data, + sparsity=sparsity, + # templates=template_plot_data, + # correlograms=ccg_plot_data, + # amplitudes=amps_plot_data, + # similarity=sim_plot_data, + # unit_locations=locs_plot_data, unit_table_properties=unit_table_properties, curation=curation, label_choices=label_choices, + + max_amplitudes_per_unit=max_amplitudes_per_unit, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + we = dp.waveform_extractor + unit_ids = dp.unit_ids + sparsity = dp.sparsity + + + # unit_ids = self.make_serializable(dp.unit_ids) + unit_ids = make_serializable(dp.unit_ids) + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # amplitudes_plotter = AmplitudesPlotter() + # v_spike_amplitudes = amplitudes_plotter.do_plot( + # dp.amplitudes, generate_url=False, display=False, backend="sortingview" + # ) + # template_plotter = UnitTemplatesPlotter() + # v_average_waveforms = template_plotter.do_plot( + # dp.templates, generate_url=False, display=False, backend="sortingview" + # ) + # xcorrelograms_plotter = CrossCorrelogramsPlotter() + # v_cross_correlograms = xcorrelograms_plotter.do_plot( + # dp.correlograms, generate_url=False, display=False, backend="sortingview" + # ) + # unitlocation_plotter = UnitLocationsPlotter() + # v_unit_locations = unitlocation_plotter.do_plot( + # dp.unit_locations, generate_url=False, display=False, backend="sortingview" + # ) + + v_spike_amplitudes = AmplitudesWidget( + we, unit_ids=unit_ids, max_spikes_per_unit=dp.max_amplitudes_per_unit, hide_unit_selector=True, + generate_url=False, display=False, backend="sortingview" + ).view + v_average_waveforms = UnitTemplatesWidget( + we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True, + generate_url=False, display=False, backend="sortingview" + ).view + v_cross_correlograms = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True, + generate_url=False, display=False, backend="sortingview").view + + v_unit_locations = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True, + generate_url=False, display=False, backend="sortingview").view + + w = TemplateSimilarityWidget(we, unit_ids=unit_ids, immediate_plot=False, + generate_url=False, display=False, backend="sortingview" ) + similarity = w.data_plot["similarity"] + print(similarity.shape) + + # similarity + similarity_scores = [] + for i1, u1 in enumerate(unit_ids): + for i2, u2 in enumerate(unit_ids): + similarity_scores.append( + vv.UnitSimilarityScore( + unit_id1=u1, unit_id2=u2, similarity=similarity[i1, i2].astype("float32") + ) + ) + + # unit ids + v_units_table = generate_unit_table_view( + dp.waveform_extractor.sorting, dp.unit_table_properties, similarity_scores=similarity_scores + ) + + if dp.curation: + v_curation = vv.SortingCuration2(label_choices=dp.label_choices) + v1 = vv.Splitter(direction="vertical", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_curation)) + else: + v1 = v_units_table + v2 = vv.Splitter( + direction="horizontal", + item1=vv.LayoutItem(v_unit_locations, stretch=0.2), + item2=vv.LayoutItem( + vv.Splitter( + direction="horizontal", + item1=vv.LayoutItem(v_average_waveforms), + item2=vv.LayoutItem( + vv.Splitter( + direction="vertical", + item1=vv.LayoutItem(v_spike_amplitudes), + item2=vv.LayoutItem(v_cross_correlograms), + ) + ), + ) + ), + ) + + # assemble layout + # v_summary = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) + self.view = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) + + # self.handle_display_and_url(v_summary, **backend_kwargs) + # return v_summary + + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + + diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 5820477dc8..ae0b898035 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -53,7 +53,7 @@ # summary from .unit_summary import UnitSummaryWidget -# from .sorting_summary import SortingSummaryWidget +from .sorting_summary import SortingSummaryWidget widget_list = [ @@ -75,7 +75,7 @@ UnitDepthsWidget, # summary UnitSummaryWidget, - # SortingSummaryWidget, + SortingSummaryWidget, ] @@ -144,4 +144,5 @@ plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget plot_unit_depths = UnitDepthsWidget plot_unit_summary = UnitSummaryWidget +plot_sorting_summary = SortingSummaryWidget From 9f9587cf1375155e6ff45b98def20b54ad656b8d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 08:43:16 +0200 Subject: [PATCH 17/31] refactor widgets : TimeseriesWidget --- src/spikeinterface/widgets/timeseries.py | 342 +++++++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 5 +- 2 files changed, 342 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/timeseries.py index 93e0358460..0e82c85b94 100644 --- a/src/spikeinterface/widgets/timeseries.py +++ b/src/spikeinterface/widgets/timeseries.py @@ -1,8 +1,10 @@ +import warnings + import numpy as np from ..core import BaseRecording, order_channels_by_depth -from .base import BaseWidget -from .utils import get_some_colors +from .base import BaseWidget, to_attr +from .utils import get_some_colors, array_to_image class TimeseriesWidget(BaseWidget): @@ -56,7 +58,7 @@ class TimeseriesWidget(BaseWidget): The output widget """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -213,6 +215,340 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from matplotlib.ticker import MaxNLocator + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + ax = self.ax + n = len(dp.channel_ids) + if dp.channel_locations is not None: + y_locs = dp.channel_locations[:, 1] + else: + y_locs = np.arange(n) + min_y = np.min(y_locs) + max_y = np.max(y_locs) + + if dp.mode == "line": + offset = dp.vspacing * (n - 1) + + for layer_key, traces in zip(dp.layer_keys, dp.list_traces): + for i, chan_id in enumerate(dp.channel_ids): + offset = dp.vspacing * i + color = dp.colors[layer_key][chan_id] + ax.plot(dp.times, offset + traces[:, i], color=color) + ax.get_lines()[-1].set_label(layer_key) + + if dp.show_channel_ids: + ax.set_yticks(np.arange(n) * dp.vspacing) + channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids]) + ax.set_yticklabels(channel_labels) + else: + ax.get_yaxis().set_visible(False) + + ax.set_xlim(*dp.time_range) + ax.set_ylim(-dp.vspacing, dp.vspacing * n) + ax.get_xaxis().set_major_locator(MaxNLocator(prune="both")) + ax.set_xlabel("time (s)") + if dp.add_legend: + ax.legend(loc="upper right") + + elif dp.mode == "map": + assert len(dp.list_traces) == 1, 'plot_timeseries with mode="map" do not support multi recording' + assert len(dp.clims) == 1 + clim = list(dp.clims.values())[0] + extent = (dp.time_range[0], dp.time_range[1], min_y, max_y) + im = ax.imshow( + dp.list_traces[0].T, interpolation="nearest", origin="lower", aspect="auto", extent=extent, cmap=dp.cmap + ) + + im.set_clim(*clim) + + if dp.with_colorbar: + self.figure.colorbar(im, ax=ax) + + if dp.show_channel_ids: + ax.set_yticks(np.linspace(min_y, max_y, n) + (max_y - min_y) / n * 0.5) + channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids]) + ax.set_yticklabels(channel_labels) + else: + ax.get_yaxis().set_visible(False) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .ipywidgets_utils import check_ipywidget_backend, make_timeseries_controller, make_channel_controller, make_scale_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + recordings = data_plot["recordings"] + + # first layer + rec0 = recordings[data_plot["layer_keys"][0]] + + cm = 1 / 2.54 + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + ratios = [0.1, 0.8, 0.2] + + with plt.ioff(): + output = widgets.Output() + with output: + self.figure, self.ax = plt.subplots(figsize=(0.9 * ratios[1] * width_cm * cm, height_cm * cm)) + plt.show() + + t_start = 0.0 + t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() + + ts_widget, ts_controller = make_timeseries_controller( + t_start, + t_stop, + data_plot["layer_keys"], + rec0.get_num_segments(), + data_plot["time_range"], + data_plot["mode"], + False, + width_cm, + ) + + ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) + + scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) + + self.controller = ts_controller + self.controller.update(ch_controller) + self.controller.update(scale_controller) + + # mpl_plotter = MplTimeseriesPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) + # for w in self.controller.values(): + # if isinstance(w, widgets.Button): + # w.on_click(self.updater) + # else: + # w.observe(self.updater) + + self.recordings = data_plot["recordings"] + self.return_scaled = data_plot["return_scaled"] + self.list_traces = None + self.actual_segment_index = self.controller["segment_index"].value + + self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] + self.t_stops = [ + self.rec0.get_num_samples(segment_index=seg_index) / self.rec0.get_sampling_frequency() + for seg_index in range(self.rec0.get_num_segments()) + ] + + for w in self.controller.values(): + if isinstance(w, widgets.Button): + w.on_click(self._update_ipywidget) + else: + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + center=self.figure.canvas, + footer=ts_widget, + left_sidebar=scale_widget, + right_sidebar=ch_widget, + pane_heights=[0, 6, 1], + pane_widths=ratios, + ) + + # a first update + # self.updater(None) + self._update_ipywidget(None) + + if backend_kwargs["display"]: + # self.check_backend() + display(self.widget) + + def _update_ipywidget(self, change): + import ipywidgets.widgets as widgets + + # if changing the layer_key, no need to retrieve and process traces + retrieve_traces = True + scale_up = False + scale_down = False + if change is not None: + for cname, c in self.controller.items(): + if isinstance(change, dict): + if change["owner"] is c and cname == "layer_key": + retrieve_traces = False + elif isinstance(change, widgets.Button): + if change is c and cname == "plus": + scale_up = True + if change is c and cname == "minus": + scale_down = True + + t_start = self.controller["t_start"].value + window = self.controller["window"].value + layer_key = self.controller["layer_key"].value + segment_index = self.controller["segment_index"].value + mode = self.controller["mode"].value + chan_start, chan_stop = self.controller["channel_inds"].value + + if mode == "line": + self.controller["all_layers"].layout.visibility = "visible" + all_layers = self.controller["all_layers"].value + elif mode == "map": + self.controller["all_layers"].layout.visibility = "hidden" + all_layers = False + + if all_layers: + self.controller["layer_key"].layout.visibility = "hidden" + else: + self.controller["layer_key"].layout.visibility = "visible" + + if chan_start == chan_stop: + chan_stop += 1 + channel_indices = np.arange(chan_start, chan_stop) + + t_stop = self.t_stops[segment_index] + if self.actual_segment_index != segment_index: + # change time_slider limits + self.controller["t_start"].max = t_stop + self.actual_segment_index = segment_index + + # protect limits + if t_start >= t_stop - window: + t_start = t_stop - window + + time_range = np.array([t_start, t_start + window]) + data_plot = self.next_data_plot + + if retrieve_traces: + all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids + if self.data_plot["order"] is not None: + all_channel_ids = all_channel_ids[self.data_plot["order"]] + channel_ids = all_channel_ids[channel_indices] + if self.data_plot["order_channel_by_depth"]: + order, _ = order_channels_by_depth(self.rec0, channel_ids) + else: + order = None + times, list_traces, frame_range, channel_ids = _get_trace_list( + self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled + ) + self.list_traces = list_traces + else: + times = data_plot["times"] + list_traces = data_plot["list_traces"] + frame_range = data_plot["frame_range"] + channel_ids = data_plot["channel_ids"] + + if all_layers: + layer_keys = self.data_plot["layer_keys"] + recordings = self.recordings + list_traces_plot = self.list_traces + else: + layer_keys = [layer_key] + recordings = {layer_key: self.recordings[layer_key]} + list_traces_plot = [self.list_traces[list(self.recordings.keys()).index(layer_key)]] + + if scale_up: + if mode == "line": + data_plot["vspacing"] *= 0.8 + elif mode == "map": + data_plot["clims"] = { + layer: (1.2 * val[0], 1.2 * val[1]) for layer, val in self.data_plot["clims"].items() + } + if scale_down: + if mode == "line": + data_plot["vspacing"] *= 1.2 + elif mode == "map": + data_plot["clims"] = { + layer: (0.8 * val[0], 0.8 * val[1]) for layer, val in self.data_plot["clims"].items() + } + + self.next_data_plot["vspacing"] = data_plot["vspacing"] + self.next_data_plot["clims"] = data_plot["clims"] + + if mode == "line": + clims = None + elif mode == "map": + clims = {layer_key: self.data_plot["clims"][layer_key]} + + # matplotlib next_data_plot dict update at each call + data_plot["mode"] = mode + data_plot["frame_range"] = frame_range + data_plot["time_range"] = time_range + data_plot["with_colorbar"] = False + data_plot["recordings"] = recordings + data_plot["layer_keys"] = layer_keys + data_plot["list_traces"] = list_traces_plot + data_plot["times"] = times + data_plot["clims"] = clims + data_plot["channel_ids"] = channel_ids + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + + fig = self.ax.figure + fig.canvas.draw() + fig.canvas.flush_events() + + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + try: + import pyvips + except ImportError: + raise ImportError("To use the timeseries in sorting view you need the pyvips package.") + + backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + + assert dp.mode == "map", 'sortingview plot_timeseries is only mode="map"' + + if not dp.order_channel_by_depth: + warnings.warn( + "It is recommended to set 'order_channel_by_depth' to True " "when using the sortingview backend" + ) + + tiled_layers = [] + for layer_key, traces in zip(dp.layer_keys, dp.list_traces): + img = array_to_image( + traces, + clim=dp.clims[layer_key], + num_timepoints_per_row=dp.num_timepoints_per_row, + colormap=dp.cmap, + scalebar=True, + sampling_frequency=dp.recordings[layer_key].get_sampling_frequency(), + ) + + tiled_layers.append(vv.TiledImageLayer(layer_key, img)) + + # view_ts = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) + self.view = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) + + # self.set_view(view_ts) + + # timeseries currently doesn't display on the jupyter backend + backend_kwargs["display"] = False + # self.handle_display_and_url(view_ts, **backend_kwargs) + # return view_ts + + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + + + + + + def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): # function also used in ipywidgets plotter diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index ae0b898035..11fe0b0e92 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -2,7 +2,7 @@ from .base import backend_kwargs_desc # basics -# from .timeseries import TimeseriesWidget +from .timeseries import TimeseriesWidget # waveform from .unit_waveforms import UnitWaveformsWidget @@ -67,7 +67,7 @@ TemplateMetricsWidget, MotionWidget, TemplateSimilarityWidget, - # TimeseriesWidget, + TimeseriesWidget, UnitLocationsWidget, UnitTemplatesWidget, UnitWaveformsWidget, @@ -136,6 +136,7 @@ plot_crosscorrelograms = CrossCorrelogramsWidget plot_spike_locations = SpikeLocationsWidget plot_template_metrics = TemplateMetricsWidget +plot_timeseries = TimeseriesWidget plot_quality_metrics = QualityMetricsWidget plot_motion = MotionWidget plot_template_similarity = TemplateSimilarityWidget From 97410f9dda2133b97609e858684974d39360fa76 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 09:22:45 +0200 Subject: [PATCH 18/31] refactor widget : SpikesOnTracesWidget --- .../widgets/spikes_on_traces.py | 280 ++++++++++++++++-- src/spikeinterface/widgets/widget_list.py | 5 +- 2 files changed, 260 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index b50896df4d..9deb346387 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -1,6 +1,6 @@ import numpy as np -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors from .timeseries import TimeseriesWidget from ..core import ChannelSparsity @@ -60,7 +60,7 @@ class SpikesOnTracesWidget(BaseWidget): For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -86,28 +86,28 @@ def __init__( **backend_kwargs, ): we = waveform_extractor - recording: BaseRecording = we.recording + # recording: BaseRecording = we.recording sorting: BaseSorting = we.sorting - ts_widget = TimeseriesWidget( - recording, - segment_index, - channel_ids, - order_channel_by_depth, - time_range, - mode, - return_scaled, - cmap, - show_channel_ids, - color_groups, - color, - clim, - tile_size, - seconds_per_row, - with_colorbar, - backend, - **backend_kwargs, - ) + # ts_widget = TimeseriesWidget( + # recording, + # segment_index, + # channel_ids, + # order_channel_by_depth, + # time_range, + # mode, + # return_scaled, + # cmap, + # show_channel_ids, + # color_groups, + # color, + # clim, + # tile_size, + # seconds_per_row, + # with_colorbar, + # backend, + # **backend_kwargs, + # ) if unit_ids is None: unit_ids = sorting.get_unit_ids() @@ -133,9 +133,26 @@ def __init__( # get templates unit_locations = compute_unit_locations(we, outputs="by_unit") + options = dict( + segment_index=segment_index, + channel_ids=channel_ids, + order_channel_by_depth=order_channel_by_depth, + time_range=time_range, + mode=mode, + return_scaled=return_scaled, + cmap=cmap, + show_channel_ids=show_channel_ids, + color_groups=color_groups, + color=color, + clim=clim, + tile_size=tile_size, + with_colorbar=with_colorbar, + ) + plot_data = dict( - timeseries=ts_widget.plot_data, + # timeseries=ts_widget.plot_data, waveform_extractor=waveform_extractor, + options=options, unit_ids=unit_ids, sparsity=sparsity, unit_colors=unit_colors, @@ -143,3 +160,220 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + dp = to_attr(data_plot) + we = dp.waveform_extractor + recording = we.recording + sorting = we.sorting + + + + # first plot time series + # tsplotter = TimeseriesPlotter() + # data_plot["timeseries"]["add_legend"] = False + # tsplotter.do_plot(dp.timeseries, **backend_kwargs) + # self.ax = tsplotter.ax + # self.axes = tsplotter.axes + # self.figure = tsplotter.figure + + # first plot time series + ts_widget = TimeseriesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) + self.ax = ts_widget.ax + self.axes = ts_widget.axes + self.figure = ts_widget.figure + + + ax = self.ax + + # we = dp.waveform_extractor + # sorting = dp.waveform_extractor.sorting + # frame_range = dp.timeseries["frame_range"] + # segment_index = dp.timeseries["segment_index"] + # min_y = np.min(dp.timeseries["channel_locations"][:, 1]) + # max_y = np.max(dp.timeseries["channel_locations"][:, 1]) + + frame_range = ts_widget.data_plot["frame_range"] + segment_index = ts_widget.data_plot["segment_index"] + min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1]) + max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1]) + + + # n = len(dp.timeseries["channel_ids"]) + # order = dp.timeseries["order"] + n = len(ts_widget.data_plot["channel_ids"]) + order = ts_widget.data_plot["order"] + + if order is None: + order = np.arange(n) + + if ax.get_legend() is not None: + ax.get_legend().remove() + + # loop through units and plot a scatter of spikes at estimated location + handles = [] + labels = [] + + for unit in dp.unit_ids: + spike_frames = sorting.get_unit_spike_train(unit, segment_index=segment_index) + spike_start, spike_end = np.searchsorted(spike_frames, frame_range) + + chan_ids = dp.sparsity.unit_id_to_channel_ids[unit] + + spike_frames_to_plot = spike_frames[spike_start:spike_end] + + # if dp.timeseries["mode"] == "map": + if dp.options["mode"] == "map": + spike_times_to_plot = sorting.get_unit_spike_train( + unit, segment_index=segment_index, return_times=True + )[spike_start:spike_end] + unit_y_loc = min_y + max_y - dp.unit_locations[unit][1] + # markers = np.ones_like(spike_frames_to_plot) * (min_y + max_y - dp.unit_locations[unit][1]) + width = 2 * 1e-3 + ellipse_kwargs = dict(width=width, height=10, fc="none", ec=dp.unit_colors[unit], lw=2) + patches = [Ellipse((s, unit_y_loc), **ellipse_kwargs) for s in spike_times_to_plot] + for p in patches: + ax.add_patch(p) + handles.append( + Line2D( + [0], + [0], + ls="", + marker="o", + markersize=5, + markeredgewidth=2, + markeredgecolor=dp.unit_colors[unit], + markerfacecolor="none", + ) + ) + labels.append(unit) + else: + # construct waveforms + label_set = False + if len(spike_frames_to_plot) > 0: + # vspacing = dp.timeseries["vspacing"] + # traces = dp.timeseries["list_traces"][0] + vspacing = ts_widget.data_plot["vspacing"] + traces = ts_widget.data_plot["list_traces"][0] + + waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] + # waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1) + waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1) + + # times = dp.timeseries["times"][waveform_idxs] + times = ts_widget.data_plot["times"][waveform_idxs] + + # discontinuity + times[:, -1] = np.nan + times_r = times.reshape(times.shape[0] * times.shape[1]) + waveforms = traces[waveform_idxs] # [:, :, order] + waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) + + # for i, chan_id in enumerate(dp.timeseries["channel_ids"]): + for i, chan_id in enumerate(ts_widget.data_plot["channel_ids"]): + offset = vspacing * i + if chan_id in chan_ids: + l = ax.plot(times_r, offset + waveforms_r[:, i], color=dp.unit_colors[unit]) + if not label_set: + handles.append(l[0]) + labels.append(unit) + label_set = True + ax.legend(handles, labels) + + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + dp = to_attr(data_plot) + we = dp.waveform_extractor + + + ratios = [0.2, 0.8] + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + backend_kwargs_ts = backend_kwargs.copy() + backend_kwargs_ts["width_cm"] = ratios[1] * backend_kwargs_ts["width_cm"] + backend_kwargs_ts["display"] = False + height_cm = backend_kwargs["height_cm"] + width_cm = backend_kwargs["width_cm"] + + # plot timeseries + # tsplotter = TimeseriesPlotter() + # data_plot["timeseries"]["add_legend"] = False + # tsplotter.do_plot(data_plot["timeseries"], **backend_kwargs_ts) + + # ts_w = tsplotter.widget + # ts_updater = tsplotter.updater + + ts_widget = TimeseriesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) + self.ax = ts_widget.ax + self.axes = ts_widget.axes + self.figure = ts_widget.figure + + + # we = data_plot["waveform_extractor"] + + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm + ) + + self.controller = dict() + # self.controller = ts_updater.controller + self.controller.update(ts_widget.controller) + self.controller.update(unit_controller) + + # mpl_plotter = MplSpikesOnTracesPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, ts_updater, self.controller) + # for w in self.controller.values(): + # w.observe(self.updater) + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + + self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0]) + + # a first update + # self.updater(None) + self._update_ipywidget(None) + + if backend_kwargs["display"]: + # self.check_backend() + display(self.widget) + + def _update_ipywidget(self, change): + self.ax.clear() + + unit_ids = self.controller["unit_ids"].value + + # update ts + # self.ts_updater.__call__(change) + + # update data plot + # data_plot = self.data_plot.copy() + data_plot = self.next_data_plot + # data_plot["timeseries"] = self.ts_updater.next_data_plot + data_plot["unit_ids"] = unit_ids + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 11fe0b0e92..db73dbc5ec 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -18,7 +18,7 @@ # drift/motion # spikes-traces -# from .spikes_on_traces import SpikesOnTracesWidget +from .spikes_on_traces import SpikesOnTracesWidget # PC related @@ -63,7 +63,7 @@ CrossCorrelogramsWidget, QualityMetricsWidget, SpikeLocationsWidget, - # SpikesOnTracesWidget, + SpikesOnTracesWidget, TemplateMetricsWidget, MotionWidget, TemplateSimilarityWidget, @@ -135,6 +135,7 @@ plot_autocorrelograms = AutoCorrelogramsWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_spike_locations = SpikeLocationsWidget +plot_spikes_on_traces = SpikesOnTracesWidget plot_template_metrics = TemplateMetricsWidget plot_timeseries = TimeseriesWidget plot_quality_metrics = QualityMetricsWidget From d159145376368e8a48bc5340fd868d17a95eff3c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Jul 2023 07:27:31 +0000 Subject: [PATCH 19/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../widgets/all_amplitudes_distributions.py | 3 +- src/spikeinterface/widgets/amplitudes.py | 14 ++--- .../widgets/autocorrelograms.py | 4 +- src/spikeinterface/widgets/base.py | 32 +++++------ .../widgets/crosscorrelograms.py | 4 +- .../widgets/ipywidgets_utils.py | 3 +- .../widgets/matplotlib_utils.py | 4 +- src/spikeinterface/widgets/metrics.py | 10 ++-- src/spikeinterface/widgets/motion.py | 4 +- src/spikeinterface/widgets/sorting_summary.py | 47 +++++++++------- .../widgets/sortingview_utils.py | 11 ++-- src/spikeinterface/widgets/spike_locations.py | 11 ++-- .../widgets/spikes_on_traces.py | 16 ++---- .../widgets/template_similarity.py | 4 +- .../widgets/tests/test_widgets.py | 2 +- src/spikeinterface/widgets/timeseries.py | 13 +++-- src/spikeinterface/widgets/unit_depths.py | 1 - src/spikeinterface/widgets/unit_locations.py | 18 ++----- src/spikeinterface/widgets/unit_summary.py | 54 +++++++++++-------- src/spikeinterface/widgets/unit_templates.py | 5 +- src/spikeinterface/widgets/unit_waveforms.py | 9 ++-- src/spikeinterface/widgets/widget_list.py | 9 ++-- 22 files changed, 128 insertions(+), 150 deletions(-) diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index 18585a4f96..d3cca278c9 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -55,7 +55,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.patches import Ellipse from matplotlib.lines import Line2D - dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) # self.make_mpl_figure(**backend_kwargs) @@ -85,4 +84,4 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if np.max(ylims) < 0: ax.set_ylim(min(ylims), 0) if np.min(ylims) > 0: - ax.set_ylim(0, max(ylims)) \ No newline at end of file + ax.set_ylim(0, max(ylims)) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 7c76d26204..a2a3ccff3b 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -121,9 +121,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.patches import Ellipse from matplotlib.lines import Line2D - - - dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -168,7 +165,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.plot_legend: # if self.legend is not None: - if hasattr(self, 'legend') and self.legend is not None: + if hasattr(self, "legend") and self.legend is not None: self.legend.remove() self.legend = self.figure.legend( loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True @@ -186,7 +183,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import ipywidgets.widgets as widgets from IPython.display import display from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller - + check_ipywidget_backend() self.next_data_plot = data_plot.copy() @@ -232,7 +229,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.widget = widgets.AppLayout( # center=fig.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer - center=self.figure.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer + center=self.figure.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + footer=footer, ) # a first update @@ -241,7 +241,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: # self.check_backend() - display(self.widget) + display(self.widget) def _update_ipywidget(self, change): # self.fig.clear() diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index f07246efa6..e7b5014367 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -41,7 +41,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) - # unit_ids = self.make_serializable(dp.unit_ids) + # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) ac_items = [] @@ -63,6 +63,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - - AutoCorrelogramsWidget.__doc__ = CrossCorrelogramsWidget.__doc__ diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 7b0ba0454e..a1cc76eb19 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -19,7 +19,6 @@ def set_default_plotter_backend(backend): default_backend_ = backend - backend_kwargs_desc = { "matplotlib": { "figure": "Matplotlib figure. When None, it is created. Default None", @@ -29,33 +28,37 @@ def set_default_plotter_backend(backend): "figsize": "Size of matplotlib figure. Default None", "figtitle": "The figure title. Default None", }, - 'sortingview': { + "sortingview": { "generate_url": "If True, the figurl URL is generated and printed. Default True", "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", "figlabel": "The figurl figure label. Default None", "height": "The height of the sortingview View in jupyter. Default None", }, - "ipywidgets" : { - "width_cm": "Width of the figure in cm (default 10)", - "height_cm": "Height of the figure in cm (default 6)", - "display": "If True, widgets are immediately displayed", + "ipywidgets": { + "width_cm": "Width of the figure in cm (default 10)", + "height_cm": "Height of the figure in cm (default 6)", + "display": "If True, widgets are immediately displayed", }, - } default_backend_kwargs = { "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None}, "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, - "ipywidgets" : {"width_cm": 25, "height_cm": 10, "display": True}, + "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True}, } - class BaseWidget: # this need to be reset in the subclass possible_backends = None - def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_kwargs, ): + def __init__( + self, + data_plot=None, + backend=None, + immediate_plot=True, + **backend_kwargs, + ): # every widgets must prepare a dict "plot_data" in the init self.data_plot = data_plot backend = self.check_backend(backend) @@ -70,16 +73,16 @@ def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_ ) backend_kwargs_ = default_backend_kwargs[self.backend].copy() backend_kwargs_.update(backend_kwargs) - + self.backend_kwargs = backend_kwargs_ if immediate_plot: - print('immediate_plot', self.backend, self.backend_kwargs) + print("immediate_plot", self.backend, self.backend_kwargs) self.do_plot(self.backend, **self.backend_kwargs) @classmethod def get_possible_backends(cls): - return [ k for k in default_backend_kwargs if hasattr(cls, f"plot_{k}") ] + return [k for k in default_backend_kwargs if hasattr(cls, f"plot_{k}")] def check_backend(self, backend): if backend is None: @@ -88,7 +91,6 @@ def check_backend(self, backend): f"{backend} backend not available! Available backends are: " f"{self.get_possible_backends()}" ) return backend - # def check_backend_kwargs(self, plotter, backend, **backend_kwargs): # plotter_kwargs = plotter.default_backend_kwargs @@ -102,7 +104,7 @@ def check_backend(self, backend): def do_plot(self, backend, **backend_kwargs): # backend = self.check_backend(backend) - func = getattr(self, f'plot_{backend}') + func = getattr(self, f"plot_{backend}") func(self.data_plot, **self.backend_kwargs) # @classmethod diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index eed76c3e04..4b83e61b69 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -124,9 +124,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ) ) - self.view = vv.CrossCorrelograms( - cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector - ) + self.view = vv.CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector) # self.handle_display_and_url(v_cross_correlograms, **backend_kwargs) # return v_cross_correlograms diff --git a/src/spikeinterface/widgets/ipywidgets_utils.py b/src/spikeinterface/widgets/ipywidgets_utils.py index 4490cc3365..a7c571d1f0 100644 --- a/src/spikeinterface/widgets/ipywidgets_utils.py +++ b/src/spikeinterface/widgets/ipywidgets_utils.py @@ -2,14 +2,13 @@ import numpy as np - def check_ipywidget_backend(): import matplotlib + mpl_backend = matplotlib.get_backend() assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" - def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): time_slider = widgets.FloatSlider( orientation="horizontal", diff --git a/src/spikeinterface/widgets/matplotlib_utils.py b/src/spikeinterface/widgets/matplotlib_utils.py index 6ccaaf5840..fb347552b1 100644 --- a/src/spikeinterface/widgets/matplotlib_utils.py +++ b/src/spikeinterface/widgets/matplotlib_utils.py @@ -65,11 +65,11 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, figure.suptitle(figtitle) return figure, axes, ax - + # self.figure = figure # self.ax = ax # axes is always a 2D array of ax # self.axes = axes # if figtitle is not None: - # self.figure.suptitle(figtitle) \ No newline at end of file + # self.figure.suptitle(figtitle) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 207e3a8a6c..6551bb067e 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -91,7 +91,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): backend_kwargs["figsize"] = (2 * num_metrics, 2 * num_metrics) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = num_metrics ** 2 + backend_kwargs["num_axes"] = num_metrics**2 backend_kwargs["ncols"] = num_metrics all_unit_ids = metrics.index.values @@ -128,7 +128,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure.subplots_adjust(top=0.8, wspace=0.2, hspace=0.2) - def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets @@ -169,7 +168,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): for w in self.controller.values(): w.observe(self._update_ipywidget) - self.widget = widgets.AppLayout( center=self.figure.canvas, left_sidebar=unit_widget, @@ -203,7 +201,7 @@ def _update_ipywidget(self, change): # here we do a trick: we just update colors # if hasattr(self.mpl_plotter, "patches"): if hasattr(self, "patches"): - # for p in self.mpl_plotter.patches: + # for p in self.mpl_plotter.patches: for p in self.patches: p.set_color(colors) p.set_sizes(sizes) @@ -242,7 +240,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): unit_ids = metrics.index.values else: unit_ids = dp.unit_ids - # unit_ids = self.make_serializable(unit_ids) + # unit_ids = self.make_serializable(unit_ids) unit_ids = make_serializable(unit_ids) metrics_sv = [] @@ -283,4 +281,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # self.handle_display_and_url(view, **backend_kwargs) # return view - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) \ No newline at end of file + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 48aba8de47..1ebbb71743 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -76,10 +76,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks - dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - assert backend_kwargs["axes"] is None assert backend_kwargs["ax"] is None @@ -191,4 +189,4 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax3.set_ylabel("Depth [um]") ax3.set_title("Motion vectors") axes.append(ax3) - self.axes = np.array(axes) \ No newline at end of file + self.axes = np.array(axes) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index bdf692888f..5498df9a33 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -78,7 +78,6 @@ def __init__( unit_table_properties=unit_table_properties, curation=curation, label_choices=label_choices, - max_amplitudes_per_unit=max_amplitudes_per_unit, ) @@ -93,7 +92,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): unit_ids = dp.unit_ids sparsity = dp.sparsity - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) @@ -117,21 +115,34 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # ) v_spike_amplitudes = AmplitudesWidget( - we, unit_ids=unit_ids, max_spikes_per_unit=dp.max_amplitudes_per_unit, hide_unit_selector=True, - generate_url=False, display=False, backend="sortingview" + we, + unit_ids=unit_ids, + max_spikes_per_unit=dp.max_amplitudes_per_unit, + hide_unit_selector=True, + generate_url=False, + display=False, + backend="sortingview", ).view v_average_waveforms = UnitTemplatesWidget( - we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True, - generate_url=False, display=False, backend="sortingview" + we, + unit_ids=unit_ids, + sparsity=sparsity, + hide_unit_selector=True, + generate_url=False, + display=False, + backend="sortingview", + ).view + v_cross_correlograms = CrossCorrelogramsWidget( + we, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, display=False, backend="sortingview" ).view - v_cross_correlograms = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True, - generate_url=False, display=False, backend="sortingview").view - - v_unit_locations = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True, - generate_url=False, display=False, backend="sortingview").view - - w = TemplateSimilarityWidget(we, unit_ids=unit_ids, immediate_plot=False, - generate_url=False, display=False, backend="sortingview" ) + + v_unit_locations = UnitLocationsWidget( + we, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, display=False, backend="sortingview" + ).view + + w = TemplateSimilarityWidget( + we, unit_ids=unit_ids, immediate_plot=False, generate_url=False, display=False, backend="sortingview" + ) similarity = w.data_plot["similarity"] print(similarity.shape) @@ -140,9 +151,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): for i1, u1 in enumerate(unit_ids): for i2, u2 in enumerate(unit_ids): similarity_scores.append( - vv.UnitSimilarityScore( - unit_id1=u1, unit_id2=u2, similarity=similarity[i1, i2].astype("float32") - ) + vv.UnitSimilarityScore(unit_id1=u1, unit_id2=u2, similarity=similarity[i1, i2].astype("float32")) ) # unit ids @@ -179,7 +188,5 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # self.handle_display_and_url(v_summary, **backend_kwargs) # return v_summary - - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) diff --git a/src/spikeinterface/widgets/sortingview_utils.py b/src/spikeinterface/widgets/sortingview_utils.py index 90dfcb77a3..f5339b4bbb 100644 --- a/src/spikeinterface/widgets/sortingview_utils.py +++ b/src/spikeinterface/widgets/sortingview_utils.py @@ -3,8 +3,6 @@ from ..core.core_tools import check_json - - sortingview_backend_kwargs_desc = { "generate_url": "If True, the figurl URL is generated and printed. Default True", "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", @@ -14,7 +12,6 @@ sortingview_default_backend_kwargs = {"generate_url": True, "display": True, "figlabel": None, "height": None} - def make_serializable(*args): dict_to_serialize = {int(i): a for i, a in enumerate(args)} serializable_dict = check_json(dict_to_serialize) @@ -25,6 +22,7 @@ def make_serializable(*args): returns = returns[0] return returns + def is_notebook() -> bool: try: shell = get_ipython().__class__.__name__ @@ -37,6 +35,7 @@ def is_notebook() -> bool: except NameError: return False + def handle_display_and_url(widget, view, **backend_kwargs): url = None if is_notebook() and backend_kwargs["display"]: @@ -44,14 +43,12 @@ def handle_display_and_url(widget, view, **backend_kwargs): if backend_kwargs["generate_url"]: figlabel = backend_kwargs.get("figlabel") if figlabel is None: - # figlabel = widget.default_label + # figlabel = widget.default_label figlabel = "" url = view.url(label=figlabel) print(url) - - return url - + return url def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=None): diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index d32c3c2f4c..06495409cf 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -111,7 +111,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.lines import Line2D from probeinterface import ProbeGroup - from probeinterface.plotting import plot_probe + from probeinterface.plotting import plot_probe dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -169,7 +169,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ] if dp.plot_legend: # if self.legend is not None: - if hasattr(self, 'legend') and self.legend is not None: + if hasattr(self, "legend") and self.legend is not None: self.legend.remove() self.legend = self.figure.legend( handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True @@ -245,13 +245,11 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # self.updater(None) self._update_ipywidget(None) - if backend_kwargs["display"]: # self.check_backend() display(self.widget) def _update_ipywidget(self, change): - self.ax.clear() unit_ids = self.controller["unit_ids"].value @@ -272,7 +270,6 @@ def _update_ipywidget(self, change): fig.canvas.draw() fig.canvas.flush_events() - def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url @@ -282,7 +279,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): spike_locations = dp.spike_locations # ensure serializable for sortingview - # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) + # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} @@ -331,8 +328,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - - def estimate_axis_lims(spike_locations, quantile=0.02): # set proper axis limits all_locs = np.concatenate(list(spike_locations.values())) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 9deb346387..0aeb923f38 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -173,8 +173,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): recording = we.recording sorting = we.sorting - - # first plot time series # tsplotter = TimeseriesPlotter() # data_plot["timeseries"]["add_legend"] = False @@ -189,7 +187,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.axes = ts_widget.axes self.figure = ts_widget.figure - ax = self.ax # we = dp.waveform_extractor @@ -204,7 +201,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1]) max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1]) - # n = len(dp.timeseries["channel_ids"]) # order = dp.timeseries["order"] n = len(ts_widget.data_plot["channel_ids"]) @@ -263,10 +259,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): traces = ts_widget.data_plot["list_traces"][0] waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - # waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1) + # waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1) waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1) - # times = dp.timeseries["times"][waveform_idxs] + # times = dp.timeseries["times"][waveform_idxs] times = ts_widget.data_plot["times"][waveform_idxs] # discontinuity @@ -286,7 +282,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): label_set = True ax.legend(handles, labels) - def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets @@ -300,7 +295,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) we = dp.waveform_extractor - ratios = [0.2, 0.8] # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -323,15 +317,14 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.axes = ts_widget.axes self.figure = ts_widget.figure - # we = data_plot["waveform_extractor"] - + unit_widget, unit_controller = make_unit_controller( data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm ) self.controller = dict() - # self.controller = ts_updater.controller + # self.controller = ts_updater.controller self.controller.update(ts_widget.controller) self.controller.update(unit_controller) @@ -344,7 +337,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): for w in self.controller.values(): w.observe(self._update_ipywidget) - self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0]) # a first update diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index 93b9a49f49..a6e0356db1 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -62,7 +62,7 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - + def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from .matplotlib_utils import make_mpl_figure @@ -91,7 +91,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) @@ -112,4 +111,3 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # self.handle_display_and_url(view, **backend_kwargs) # return view self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 4ddec4134b..610da470e8 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -13,7 +13,7 @@ from spikeinterface import extract_waveforms, load_waveforms, download_dataset, compute_sparsity -# from spikeinterface.widgets import HAVE_MPL, HAVE_SV +# from spikeinterface.widgets import HAVE_MPL, HAVE_SV import spikeinterface.extractors as se diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/timeseries.py index 0e82c85b94..86e886babc 100644 --- a/src/spikeinterface/widgets/timeseries.py +++ b/src/spikeinterface/widgets/timeseries.py @@ -284,7 +284,12 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_timeseries_controller, make_channel_controller, make_scale_controller + from .ipywidgets_utils import ( + check_ipywidget_backend, + make_timeseries_controller, + make_channel_controller, + make_scale_controller, + ) check_ipywidget_backend() @@ -499,7 +504,6 @@ def _update_ipywidget(self, change): fig.canvas.draw() fig.canvas.flush_events() - def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url @@ -545,11 +549,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - - - - - def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 9b710815e4..faf9198c0d 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -74,4 +74,3 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.set_xlabel("amplitude") ax.set_ylabel("depth [um]") ax.set_xlim(0, max(dp.unit_amplitudes) * 1.2) - diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 725a4c3023..9e35f7b32c 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -79,7 +79,7 @@ def __init__( plot_legend=plot_legend, hide_axis=hide_axis, ) - + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): @@ -90,17 +90,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.patches import Ellipse from matplotlib.lines import Line2D - - - dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - unit_locations = dp.unit_locations probegroup = ProbeGroup.from_dict(dp.probegroup_dict) @@ -161,8 +156,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ] if dp.plot_legend: - if hasattr(self, 'legend') and self.legend is not None: - # if self.legend is not None: + if hasattr(self, "legend") and self.legend is not None: + # if self.legend is not None: self.legend.remove() self.legend = self.figure.legend( handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True @@ -171,9 +166,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.hide_axis: self.ax.axis("off") - - - def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets @@ -188,7 +180,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -227,7 +218,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: display(self.widget) - + def _update_ipywidget(self, change): self.ax.clear() @@ -283,4 +274,3 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # self.handle_display_and_url(view, **backend_kwargs) self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 68fa8b77d2..66f522e3ca 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -109,13 +109,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .matplotlib_utils import make_mpl_figure dp = to_attr(data_plot) - + unit_id = dp.unit_id we = dp.we unit_colors = dp.unit_colors sparsity = dp.sparsity - # force the figure without axes if "figsize" not in backend_kwargs: backend_kwargs["figsize"] = (18, 7) @@ -136,7 +135,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ncols += 1 # if dp.plot_data_amplitudes is not None : if we.is_extension("spike_amplitudes"): - nrows += 1 gs = fig.add_gridspec(nrows, ncols) @@ -145,9 +143,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1 = fig.add_subplot(gs[:2, 0]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, - backend='matplotlib', ax=ax1) - + we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, backend="matplotlib", ax=ax1 + ) + unit_locations = we.load_extension("unit_locations").get_data(outputs="by_unit") unit_location = unit_locations[unit_id] # x, y = dp.unit_location[0], dp.unit_location[1] @@ -161,22 +159,30 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax2 = fig.add_subplot(gs[:2, 1]) # UnitWaveformPlotter().do_plot(dp.plot_data_waveforms, ax=ax2) w = UnitWaveformsWidget( - we, - unit_ids=[unit_id], - unit_colors=unit_colors, - plot_templates=True, - same_axis=True, - plot_legend=False, - sparsity=sparsity, - backend='matplotlib', ax=ax2) - + we, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_templates=True, + same_axis=True, + plot_legend=False, + sparsity=sparsity, + backend="matplotlib", + ax=ax2, + ) + ax2.set_title(None) ax3 = fig.add_subplot(gs[:2, 2]) # UnitWaveformDensityMapPlotter().do_plot(dp.plot_data_waveform_density, ax=ax3) UnitWaveformDensityMapWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, same_axis=False, - backend='matplotlib', ax=ax3) + we, + unit_ids=[unit_id], + unit_colors=unit_colors, + use_max_channel=True, + same_axis=False, + backend="matplotlib", + ax=ax3, + ) ax3.set_ylabel(None) # if dp.plot_data_acc is not None: @@ -187,10 +193,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): we, unit_ids=[unit_id], unit_colors=unit_colors, - backend='matplotlib', ax=ax4, + backend="matplotlib", + ax=ax4, ) - ax4.set_title(None) ax4.set_yticks([]) @@ -201,7 +207,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): axes = np.array([ax5, ax6]) # AmplitudesPlotter().do_plot(dp.plot_data_amplitudes, axes=axes) AmplitudesWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True, - backend='matplotlib', axes=axes) + we, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_legend=False, + plot_histograms=True, + backend="matplotlib", + axes=axes, + ) fig.suptitle(f"unit_id: {dp.unit_id}") diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 84856d2df4..04b26e300f 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -1,5 +1,6 @@ from .unit_waveforms import UnitWaveformsWidget -from .base import to_attr +from .base import to_attr + class UnitTemplatesWidget(UnitWaveformsWidget): # possible_backends = {} @@ -56,6 +57,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - - UnitTemplatesWidget.__doc__ = UnitWaveformsWidget.__doc__ diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 49c75bf046..833f13881d 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -250,7 +250,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.same_axis and dp.plot_legend: # if self.legend is not None: - if hasattr(self, 'legend') and self.legend is not None: + if hasattr(self, "legend") and self.legend is not None: self.legend.remove() self.legend = self.figure.legend( loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True @@ -326,7 +326,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): for w in self.controller.values(): w.observe(self._update_ipywidget) - self.widget = widgets.AppLayout( center=self.fig_wf.canvas, left_sidebar=unit_widget, @@ -342,7 +341,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: # self.check_backend() display(self.widget) - + def _update_ipywidget(self, change): self.fig_wf.clear() self.ax_probe.clear() @@ -373,10 +372,10 @@ def _update_ipywidget(self, change): # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) if same_axis: - # self.mpl_plotter.ax.axis("equal") + # self.mpl_plotter.ax.axis("equal") self.ax.axis("equal") if hide_axis: - # self.mpl_plotter.ax.axis("off") + # self.mpl_plotter.ax.axis("off") self.ax.axis("off") else: if hide_axis: diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index db73dbc5ec..a753c78d4a 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -1,4 +1,4 @@ -# from .base import define_widget_function_from_class +# from .base import define_widget_function_from_class from .base import backend_kwargs_desc # basics @@ -90,12 +90,12 @@ **backend_kwargs: kwargs {backend_kwargs} """ - # backend_str = f" {list(wcls.possible_backends.keys())}" + # backend_str = f" {list(wcls.possible_backends.keys())}" backend_str = f" {wcls.get_possible_backends()}" backend_kwargs_str = "" - # for backend, backend_plotter in wcls.possible_backends.items(): + # for backend, backend_plotter in wcls.possible_backends.items(): for backend in wcls.get_possible_backends(): - # backend_kwargs_desc = backend_plotter.backend_kwargs_desc + # backend_kwargs_desc = backend_plotter.backend_kwargs_desc kwargs_desc = backend_kwargs_desc[backend] if len(kwargs_desc) > 0: backend_kwargs_str += f"\n {backend}:\n\n" @@ -147,4 +147,3 @@ plot_unit_depths = UnitDepthsWidget plot_unit_summary = UnitSummaryWidget plot_sorting_summary = SortingSummaryWidget - From 9768010ab2721c6814ca0aa00d395f00b9b4d84c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 10:30:02 +0200 Subject: [PATCH 20/31] wip --- src/spikeinterface/widgets/base.py | 8 ++++---- src/spikeinterface/widgets/sortingview_utils.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 7b0ba0454e..219787d87a 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -74,8 +74,8 @@ def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_ self.backend_kwargs = backend_kwargs_ if immediate_plot: - print('immediate_plot', self.backend, self.backend_kwargs) - self.do_plot(self.backend, **self.backend_kwargs) + # print('immediate_plot', self.backend, self.backend_kwargs) + self.do_plot() @classmethod def get_possible_backends(cls): @@ -99,10 +99,10 @@ def check_backend(self, backend): # f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" # ) - def do_plot(self, backend, **backend_kwargs): + def do_plot(self): # backend = self.check_backend(backend) - func = getattr(self, f'plot_{backend}') + func = getattr(self, f'plot_{self.backend}') func(self.data_plot, **self.backend_kwargs) # @classmethod diff --git a/src/spikeinterface/widgets/sortingview_utils.py b/src/spikeinterface/widgets/sortingview_utils.py index 90dfcb77a3..c513c1f2b6 100644 --- a/src/spikeinterface/widgets/sortingview_utils.py +++ b/src/spikeinterface/widgets/sortingview_utils.py @@ -39,8 +39,9 @@ def is_notebook() -> bool: def handle_display_and_url(widget, view, **backend_kwargs): url = None - if is_notebook() and backend_kwargs["display"]: - display(view.jupyter(height=backend_kwargs["height"])) + # TODO: put this back when figurl-jupyter is working again + # if is_notebook() and backend_kwargs["display"]: + # display(view.jupyter(height=backend_kwargs["height"])) if backend_kwargs["generate_url"]: figlabel = backend_kwargs.get("figlabel") if figlabel is None: From 6a7d337b91d0a70de91ae5efc814bbee8f1a80de Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 10:39:43 +0200 Subject: [PATCH 21/31] remove old backend folder (matplotlib, ipywidgets, sortingview) not needed anymore --- .../widgets/ipywidgets/__init__.py | 9 - .../widgets/ipywidgets/amplitudes.py | 99 -------- .../widgets/ipywidgets/base_ipywidgets.py | 20 -- .../widgets/ipywidgets/metrics.py | 108 -------- .../widgets/ipywidgets/quality_metrics.py | 9 - .../widgets/ipywidgets/spike_locations.py | 97 -------- .../widgets/ipywidgets/spikes_on_traces.py | 145 ----------- .../widgets/ipywidgets/template_metrics.py | 9 - .../widgets/ipywidgets/timeseries.py | 232 ------------------ .../widgets/ipywidgets/unit_locations.py | 91 ------- .../widgets/ipywidgets/unit_templates.py | 11 - .../widgets/ipywidgets/unit_waveforms.py | 169 ------------- .../widgets/ipywidgets/utils.py | 97 -------- .../widgets/matplotlib/__init__.py | 17 -- .../all_amplitudes_distributions.py | 41 ---- .../widgets/matplotlib/amplitudes.py | 69 ------ .../widgets/matplotlib/autocorrelograms.py | 30 --- .../widgets/matplotlib/base_mpl.py | 102 -------- .../widgets/matplotlib/crosscorrelograms.py | 39 --- .../widgets/matplotlib/metrics.py | 50 ---- .../widgets/matplotlib/motion.py | 129 ---------- .../widgets/matplotlib/quality_metrics.py | 9 - .../widgets/matplotlib/spike_locations.py | 96 -------- .../widgets/matplotlib/spikes_on_traces.py | 104 -------- .../widgets/matplotlib/template_metrics.py | 9 - .../widgets/matplotlib/template_similarity.py | 30 --- .../widgets/matplotlib/timeseries.py | 70 ------ .../widgets/matplotlib/unit_depths.py | 22 -- .../widgets/matplotlib/unit_locations.py | 95 ------- .../widgets/matplotlib/unit_summary.py | 73 ------ .../widgets/matplotlib/unit_templates.py | 9 - .../widgets/matplotlib/unit_waveforms.py | 95 ------- .../matplotlib/unit_waveforms_density_map.py | 77 ------ .../widgets/sortingview/__init__.py | 11 - .../widgets/sortingview/amplitudes.py | 36 --- .../widgets/sortingview/autocorrelograms.py | 34 --- .../widgets/sortingview/base_sortingview.py | 103 -------- .../widgets/sortingview/crosscorrelograms.py | 37 --- .../widgets/sortingview/metrics.py | 61 ----- .../widgets/sortingview/quality_metrics.py | 11 - .../widgets/sortingview/sorting_summary.py | 86 ------- .../widgets/sortingview/spike_locations.py | 64 ----- .../widgets/sortingview/template_metrics.py | 11 - .../sortingview/template_similarity.py | 32 --- .../widgets/sortingview/timeseries.py | 54 ---- .../widgets/sortingview/unit_locations.py | 44 ---- .../widgets/sortingview/unit_templates.py | 54 ---- 47 files changed, 2900 deletions(-) delete mode 100644 src/spikeinterface/widgets/ipywidgets/__init__.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/amplitudes.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/metrics.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/quality_metrics.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/spike_locations.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/template_metrics.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/timeseries.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/unit_locations.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/unit_templates.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/unit_waveforms.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/utils.py delete mode 100644 src/spikeinterface/widgets/matplotlib/__init__.py delete mode 100644 src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py delete mode 100644 src/spikeinterface/widgets/matplotlib/amplitudes.py delete mode 100644 src/spikeinterface/widgets/matplotlib/autocorrelograms.py delete mode 100644 src/spikeinterface/widgets/matplotlib/base_mpl.py delete mode 100644 src/spikeinterface/widgets/matplotlib/crosscorrelograms.py delete mode 100644 src/spikeinterface/widgets/matplotlib/metrics.py delete mode 100644 src/spikeinterface/widgets/matplotlib/motion.py delete mode 100644 src/spikeinterface/widgets/matplotlib/quality_metrics.py delete mode 100644 src/spikeinterface/widgets/matplotlib/spike_locations.py delete mode 100644 src/spikeinterface/widgets/matplotlib/spikes_on_traces.py delete mode 100644 src/spikeinterface/widgets/matplotlib/template_metrics.py delete mode 100644 src/spikeinterface/widgets/matplotlib/template_similarity.py delete mode 100644 src/spikeinterface/widgets/matplotlib/timeseries.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_depths.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_locations.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_summary.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_templates.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_waveforms.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py delete mode 100644 src/spikeinterface/widgets/sortingview/__init__.py delete mode 100644 src/spikeinterface/widgets/sortingview/amplitudes.py delete mode 100644 src/spikeinterface/widgets/sortingview/autocorrelograms.py delete mode 100644 src/spikeinterface/widgets/sortingview/base_sortingview.py delete mode 100644 src/spikeinterface/widgets/sortingview/crosscorrelograms.py delete mode 100644 src/spikeinterface/widgets/sortingview/metrics.py delete mode 100644 src/spikeinterface/widgets/sortingview/quality_metrics.py delete mode 100644 src/spikeinterface/widgets/sortingview/sorting_summary.py delete mode 100644 src/spikeinterface/widgets/sortingview/spike_locations.py delete mode 100644 src/spikeinterface/widgets/sortingview/template_metrics.py delete mode 100644 src/spikeinterface/widgets/sortingview/template_similarity.py delete mode 100644 src/spikeinterface/widgets/sortingview/timeseries.py delete mode 100644 src/spikeinterface/widgets/sortingview/unit_locations.py delete mode 100644 src/spikeinterface/widgets/sortingview/unit_templates.py diff --git a/src/spikeinterface/widgets/ipywidgets/__init__.py b/src/spikeinterface/widgets/ipywidgets/__init__.py deleted file mode 100644 index 63d1b3a486..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .amplitudes import AmplitudesPlotter -from .quality_metrics import QualityMetricsPlotter -from .spike_locations import SpikeLocationsPlotter -from .spikes_on_traces import SpikesOnTracesPlotter -from .template_metrics import TemplateMetricsPlotter -from .timeseries import TimeseriesPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesPlotter -from .unit_waveforms import UnitWaveformPlotter diff --git a/src/spikeinterface/widgets/ipywidgets/amplitudes.py b/src/spikeinterface/widgets/ipywidgets/amplitudes.py deleted file mode 100644 index dc55b927e0..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/amplitudes.py +++ /dev/null @@ -1,99 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..amplitudes import AmplitudesWidget -from ..matplotlib.amplitudes import AmplitudesPlotter as MplAmplitudesPlotter - -from IPython.display import display - - -class AmplitudesPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - we = data_plot["waveform_extractor"] - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) - - plot_histograms = widgets.Checkbox( - value=data_plot["plot_histograms"], - description="plot histograms", - disabled=False, - ) - - footer = plot_histograms - - self.controller = {"plot_histograms": plot_histograms} - self.controller.update(unit_controller) - - mpl_plotter = MplAmplitudesPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -AmplitudesPlotter.register(AmplitudesWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, fig, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.fig = fig - self.controller = controller - - self.we = data_plot["waveform_extractor"] - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.fig.clear() - - unit_ids = self.controller["unit_ids"].value - plot_histograms = self.controller["plot_histograms"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_histograms"] = plot_histograms - - backend_kwargs = {} - backend_kwargs["figure"] = self.fig - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - self.fig.canvas.draw() - self.fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py b/src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py deleted file mode 100644 index e0eff7f330..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py +++ /dev/null @@ -1,20 +0,0 @@ -from spikeinterface.widgets.base import BackendPlotter - -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib import gridspec -import numpy as np - - -class IpywidgetsPlotter(BackendPlotter): - backend = "ipywidgets" - backend_kwargs_desc = { - "width_cm": "Width of the figure in cm (default 10)", - "height_cm": "Height of the figure in cm (default 6)", - "display": "If True, widgets are immediately displayed", - } - default_backend_kwargs = {"width_cm": 25, "height_cm": 10, "display": True} - - def check_backend(self): - mpl_backend = mpl.get_backend() - assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" diff --git a/src/spikeinterface/widgets/ipywidgets/metrics.py b/src/spikeinterface/widgets/ipywidgets/metrics.py deleted file mode 100644 index ba6859b2a1..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/metrics.py +++ /dev/null @@ -1,108 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - -from matplotlib.lines import Line2D - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..matplotlib.metrics import MetricsPlotter as MplMetricsPlotter - -from IPython.display import display - - -class MetricsPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - if data_plot["unit_ids"] is None: - data_plot["unit_ids"] = [] - - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - mpl_plotter = MplMetricsPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - left_sidebar=unit_widget, - pane_widths=ratios + [0], - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, fig, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.fig = fig - self.controller = controller - self.unit_colors = data_plot["unit_colors"] - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - unit_ids = self.controller["unit_ids"].value - - # matplotlib next_data_plot dict update at each call - all_units = list(self.unit_colors.keys()) - colors = [] - sizes = [] - for unit in all_units: - color = "gray" if unit not in unit_ids else self.unit_colors[unit] - size = 1 if unit not in unit_ids else 5 - colors.append(color) - sizes.append(size) - - # here we do a trick: we just update colors - if hasattr(self.mpl_plotter, "patches"): - for p in self.mpl_plotter.patches: - p.set_color(colors) - p.set_sizes(sizes) - else: - backend_kwargs = {} - backend_kwargs["figure"] = self.fig - self.mpl_plotter.do_plot(self.data_plot, **backend_kwargs) - - if len(unit_ids) > 0: - for l in self.fig.legends: - l.remove() - handles = [ - Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=self.unit_colors[unit]) - for unit in unit_ids - ] - labels = unit_ids - self.fig.legend( - handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - self.fig.canvas.draw() - self.fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/quality_metrics.py b/src/spikeinterface/widgets/ipywidgets/quality_metrics.py deleted file mode 100644 index 3fc368770b..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/quality_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..quality_metrics import QualityMetricsWidget -from .metrics import MetricsPlotter - - -class QualityMetricsPlotter(MetricsPlotter): - pass - - -QualityMetricsPlotter.register(QualityMetricsWidget) diff --git a/src/spikeinterface/widgets/ipywidgets/spike_locations.py b/src/spikeinterface/widgets/ipywidgets/spike_locations.py deleted file mode 100644 index 633eb0ac39..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/spike_locations.py +++ /dev/null @@ -1,97 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..spike_locations import SpikeLocationsWidget -from ..matplotlib.spike_locations import ( - SpikeLocationsPlotter as MplSpikeLocationsPlotter, -) - -from IPython.display import display - - -class SpikeLocationsPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig, ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], - list(data_plot["unit_colors"].keys()), - ratios[0] * width_cm, - height_cm, - ) - - self.controller = unit_controller - - mpl_plotter = MplSpikeLocationsPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - left_sidebar=unit_widget, - pane_widths=ratios + [0], - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -SpikeLocationsPlotter.register(SpikeLocationsWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.ax = ax - self.controller = controller - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_all_units"] = True - data_plot["plot_legend"] = True - data_plot["hide_axis"] = True - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - fig = self.ax.get_figure() - fig.canvas.draw() - fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py b/src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py deleted file mode 100644 index e5a3ebcc71..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py +++ /dev/null @@ -1,145 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from .base_ipywidgets import IpywidgetsPlotter -from .timeseries import TimeseriesPlotter -from .utils import make_unit_controller - -from ..spikes_on_traces import SpikesOnTracesWidget -from ..matplotlib.spikes_on_traces import SpikesOnTracesPlotter as MplSpikesOnTracesPlotter - -from IPython.display import display - - -class SpikesOnTracesPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - ratios = [0.2, 0.8] - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs_ts = backend_kwargs.copy() - backend_kwargs_ts["width_cm"] = ratios[1] * backend_kwargs_ts["width_cm"] - backend_kwargs_ts["display"] = False - height_cm = backend_kwargs["height_cm"] - width_cm = backend_kwargs["width_cm"] - - # plot timeseries - tsplotter = TimeseriesPlotter() - data_plot["timeseries"]["add_legend"] = False - tsplotter.do_plot(data_plot["timeseries"], **backend_kwargs_ts) - - ts_w = tsplotter.widget - ts_updater = tsplotter.updater - - we = data_plot["waveform_extractor"] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) - - self.controller = ts_updater.controller - self.controller.update(unit_controller) - - mpl_plotter = MplSpikesOnTracesPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ts_updater, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout(center=ts_w, left_sidebar=unit_widget, pane_widths=ratios + [0]) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -SpikesOnTracesPlotter.register(SpikesOnTracesWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ts_updater, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - - self.ts_updater = ts_updater - self.ax = ts_updater.ax - self.fig = self.ax.figure - self.controller = controller - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - - # update ts - # self.ts_updater.__call__(change) - - # update data plot - data_plot = self.data_plot.copy() - data_plot["timeseries"] = self.ts_updater.next_data_plot - data_plot["unit_ids"] = unit_ids - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - self.fig.canvas.draw() - self.fig.canvas.flush_events() - - # t = self.time_slider.value - # d = self.win_sizer.value - - # selected_layer = self.layer_selector.value - # segment_index = self.seg_selector.value - # mode = self.mode_selector.value - - # t_stop = self.t_stops[segment_index] - # if self.actual_segment_index != segment_index: - # # change time_slider limits - # self.time_slider.max = t_stop - # self.actual_segment_index = segment_index - - # # protect limits - # if t >= t_stop - d: - # t = t_stop - d - - # time_range = np.array([t, t+d]) - - # if mode =='line': - # # plot all layer - # layer_keys = self.data_plot['layer_keys'] - # recordings = self.recordings - # clims = None - # elif mode =='map': - # layer_keys = [selected_layer] - # recordings = {selected_layer: self.recordings[selected_layer]} - # clims = {selected_layer: self.data_plot["clims"][selected_layer]} - - # channel_ids = self.data_plot['channel_ids'] - # order = self.data_plot['order'] - # times, list_traces, frame_range, order = _get_trace_list(recordings, channel_ids, time_range, order, - # segment_index) - - # # matplotlib next_data_plot dict update at each call - # data_plot = self.next_data_plot - # data_plot['mode'] = mode - # data_plot['frame_range'] = frame_range - # data_plot['time_range'] = time_range - # data_plot['with_colorbar'] = False - # data_plot['recordings'] = recordings - # data_plot['layer_keys'] = layer_keys - # data_plot['list_traces'] = list_traces - # data_plot['times'] = times - # data_plot['clims'] = clims - - # backend_kwargs = {} - # backend_kwargs['ax'] = self.ax - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - # fig = self.ax.figure - # fig.canvas.draw() - # fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/template_metrics.py b/src/spikeinterface/widgets/ipywidgets/template_metrics.py deleted file mode 100644 index 0aea8ae428..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/template_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..template_metrics import TemplateMetricsWidget -from .metrics import MetricsPlotter - - -class TemplateMetricsPlotter(MetricsPlotter): - pass - - -TemplateMetricsPlotter.register(TemplateMetricsWidget) diff --git a/src/spikeinterface/widgets/ipywidgets/timeseries.py b/src/spikeinterface/widgets/ipywidgets/timeseries.py deleted file mode 100644 index 2448166f16..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/timeseries.py +++ /dev/null @@ -1,232 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - -from ...core import order_channels_by_depth - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_timeseries_controller, make_channel_controller, make_scale_controller - -from ..timeseries import TimeseriesWidget, _get_trace_list -from ..matplotlib.timeseries import TimeseriesPlotter as MplTimeseriesPlotter - -from IPython.display import display - - -class TimeseriesPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - recordings = data_plot["recordings"] - - # first layer - rec0 = recordings[data_plot["layer_keys"][0]] - - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - ratios = [0.1, 0.8, 0.2] - - with plt.ioff(): - output = widgets.Output() - with output: - fig, ax = plt.subplots(figsize=(0.9 * ratios[1] * width_cm * cm, height_cm * cm)) - plt.show() - - t_start = 0.0 - t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() - - ts_widget, ts_controller = make_timeseries_controller( - t_start, - t_stop, - data_plot["layer_keys"], - rec0.get_num_segments(), - data_plot["time_range"], - data_plot["mode"], - False, - width_cm, - ) - - ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) - - scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) - - self.controller = ts_controller - self.controller.update(ch_controller) - self.controller.update(scale_controller) - - mpl_plotter = MplTimeseriesPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - for w in self.controller.values(): - if isinstance(w, widgets.Button): - w.on_click(self.updater) - else: - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - footer=ts_widget, - left_sidebar=scale_widget, - right_sidebar=ch_widget, - pane_heights=[0, 6, 1], - pane_widths=ratios, - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -TimeseriesPlotter.register(TimeseriesWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - - self.ax = ax - self.controller = controller - - self.recordings = data_plot["recordings"] - self.return_scaled = data_plot["return_scaled"] - self.next_data_plot = data_plot.copy() - self.list_traces = None - - self.actual_segment_index = self.controller["segment_index"].value - - self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] - self.t_stops = [ - self.rec0.get_num_samples(segment_index=seg_index) / self.rec0.get_sampling_frequency() - for seg_index in range(self.rec0.get_num_segments()) - ] - - def __call__(self, change): - self.ax.clear() - - # if changing the layer_key, no need to retrieve and process traces - retrieve_traces = True - scale_up = False - scale_down = False - if change is not None: - for cname, c in self.controller.items(): - if isinstance(change, dict): - if change["owner"] is c and cname == "layer_key": - retrieve_traces = False - elif isinstance(change, widgets.Button): - if change is c and cname == "plus": - scale_up = True - if change is c and cname == "minus": - scale_down = True - - t_start = self.controller["t_start"].value - window = self.controller["window"].value - layer_key = self.controller["layer_key"].value - segment_index = self.controller["segment_index"].value - mode = self.controller["mode"].value - chan_start, chan_stop = self.controller["channel_inds"].value - - if mode == "line": - self.controller["all_layers"].layout.visibility = "visible" - all_layers = self.controller["all_layers"].value - elif mode == "map": - self.controller["all_layers"].layout.visibility = "hidden" - all_layers = False - - if all_layers: - self.controller["layer_key"].layout.visibility = "hidden" - else: - self.controller["layer_key"].layout.visibility = "visible" - - if chan_start == chan_stop: - chan_stop += 1 - channel_indices = np.arange(chan_start, chan_stop) - - t_stop = self.t_stops[segment_index] - if self.actual_segment_index != segment_index: - # change time_slider limits - self.controller["t_start"].max = t_stop - self.actual_segment_index = segment_index - - # protect limits - if t_start >= t_stop - window: - t_start = t_stop - window - - time_range = np.array([t_start, t_start + window]) - data_plot = self.next_data_plot - - if retrieve_traces: - all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids - if self.data_plot["order"] is not None: - all_channel_ids = all_channel_ids[self.data_plot["order"]] - channel_ids = all_channel_ids[channel_indices] - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None - times, list_traces, frame_range, channel_ids = _get_trace_list( - self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled - ) - self.list_traces = list_traces - else: - times = data_plot["times"] - list_traces = data_plot["list_traces"] - frame_range = data_plot["frame_range"] - channel_ids = data_plot["channel_ids"] - - if all_layers: - layer_keys = self.data_plot["layer_keys"] - recordings = self.recordings - list_traces_plot = self.list_traces - else: - layer_keys = [layer_key] - recordings = {layer_key: self.recordings[layer_key]} - list_traces_plot = [self.list_traces[list(self.recordings.keys()).index(layer_key)]] - - if scale_up: - if mode == "line": - data_plot["vspacing"] *= 0.8 - elif mode == "map": - data_plot["clims"] = { - layer: (1.2 * val[0], 1.2 * val[1]) for layer, val in self.data_plot["clims"].items() - } - if scale_down: - if mode == "line": - data_plot["vspacing"] *= 1.2 - elif mode == "map": - data_plot["clims"] = { - layer: (0.8 * val[0], 0.8 * val[1]) for layer, val in self.data_plot["clims"].items() - } - - self.next_data_plot["vspacing"] = data_plot["vspacing"] - self.next_data_plot["clims"] = data_plot["clims"] - - if mode == "line": - clims = None - elif mode == "map": - clims = {layer_key: self.data_plot["clims"][layer_key]} - - # matplotlib next_data_plot dict update at each call - data_plot["mode"] = mode - data_plot["frame_range"] = frame_range - data_plot["time_range"] = time_range - data_plot["with_colorbar"] = False - data_plot["recordings"] = recordings - data_plot["layer_keys"] = layer_keys - data_plot["list_traces"] = list_traces_plot - data_plot["times"] = times - data_plot["clims"] = clims - data_plot["channel_ids"] = channel_ids - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - fig = self.ax.figure - fig.canvas.draw() - fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/unit_locations.py b/src/spikeinterface/widgets/ipywidgets/unit_locations.py deleted file mode 100644 index e78c0d8fe5..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/unit_locations.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..unit_locations import UnitLocationsWidget -from ..matplotlib.unit_locations import UnitLocationsPlotter as MplUnitLocationsPlotter - -from IPython.display import display - - -class UnitLocationsPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig, ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - mpl_plotter = MplUnitLocationsPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - left_sidebar=unit_widget, - pane_widths=ratios + [0], - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -UnitLocationsPlotter.register(UnitLocationsWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.ax = ax - self.controller = controller - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_all_units"] = True - data_plot["plot_legend"] = True - data_plot["hide_axis"] = True - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - fig = self.ax.get_figure() - fig.canvas.draw() - fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/unit_templates.py b/src/spikeinterface/widgets/ipywidgets/unit_templates.py deleted file mode 100644 index 41da9d8cd3..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/unit_templates.py +++ /dev/null @@ -1,11 +0,0 @@ -from ..unit_templates import UnitTemplatesWidget -from .unit_waveforms import UnitWaveformPlotter - - -class UnitTemplatesPlotter(UnitWaveformPlotter): - def do_plot(self, data_plot, **backend_kwargs): - super().do_plot(data_plot, **backend_kwargs) - self.controller["plot_templates"].layout.visibility = "hidden" - - -UnitTemplatesPlotter.register(UnitTemplatesWidget) diff --git a/src/spikeinterface/widgets/ipywidgets/unit_waveforms.py b/src/spikeinterface/widgets/ipywidgets/unit_waveforms.py deleted file mode 100644 index 012b46038a..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/unit_waveforms.py +++ /dev/null @@ -1,169 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..unit_waveforms import UnitWaveformsWidget -from ..matplotlib.unit_waveforms import UnitWaveformPlotter as MplUnitWaveformPlotter - -from IPython.display import display - - -class UnitWaveformPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - we = data_plot["waveform_extractor"] - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.1, 0.7, 0.2] - - with plt.ioff(): - output1 = widgets.Output() - with output1: - fig_wf = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - output2 = widgets.Output() - with output2: - fig_probe, ax_probe = plt.subplots(figsize=((ratios[2] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) - - same_axis_button = widgets.Checkbox( - value=False, - description="same axis", - disabled=False, - ) - - plot_templates_button = widgets.Checkbox( - value=True, - description="plot templates", - disabled=False, - ) - - hide_axis_button = widgets.Checkbox( - value=True, - description="hide axis", - disabled=False, - ) - - footer = widgets.HBox([same_axis_button, plot_templates_button, hide_axis_button]) - - self.controller = { - "same_axis": same_axis_button, - "plot_templates": plot_templates_button, - "hide_axis": hide_axis_button, - } - self.controller.update(unit_controller) - - mpl_plotter = MplUnitWaveformPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, fig_wf, ax_probe, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig_wf.canvas, - left_sidebar=unit_widget, - right_sidebar=fig_probe.canvas, - pane_widths=ratios, - footer=footer, - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -UnitWaveformPlotter.register(UnitWaveformsWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, fig_wf, ax_probe, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.fig_wf = fig_wf - self.ax_probe = ax_probe - self.controller = controller - - self.we = data_plot["waveform_extractor"] - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.fig_wf.clear() - self.ax_probe.clear() - - unit_ids = self.controller["unit_ids"].value - same_axis = self.controller["same_axis"].value - plot_templates = self.controller["plot_templates"].value - hide_axis = self.controller["hide_axis"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["templates"] = self.we.get_all_templates(unit_ids=unit_ids) - data_plot["template_stds"] = self.we.get_all_templates(unit_ids=unit_ids, mode="std") - data_plot["same_axis"] = same_axis - data_plot["plot_templates"] = plot_templates - if data_plot["plot_waveforms"]: - data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} - - backend_kwargs = {} - - if same_axis: - backend_kwargs["ax"] = self.fig_wf.add_subplot() - data_plot["set_title"] = False - else: - backend_kwargs["figure"] = self.fig_wf - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - if same_axis: - self.mpl_plotter.ax.axis("equal") - if hide_axis: - self.mpl_plotter.ax.axis("off") - else: - if hide_axis: - for i in range(len(unit_ids)): - ax = self.mpl_plotter.axes.flatten()[i] - ax.axis("off") - - # update probe plot - channel_locations = self.we.get_channel_locations() - self.ax_probe.plot( - channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 - ) - self.ax_probe.axis("off") - self.ax_probe.axis("equal") - - for unit in unit_ids: - channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] - self.ax_probe.plot( - channel_locations[channel_inds, 0], - channel_locations[channel_inds, 1], - ls="", - marker="o", - markersize=3, - color=self.next_data_plot["unit_colors"][unit], - ) - self.ax_probe.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10) - fig_probe = self.ax_probe.get_figure() - - self.fig_wf.canvas.draw() - self.fig_wf.canvas.flush_events() - fig_probe.canvas.draw() - fig_probe.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/utils.py b/src/spikeinterface/widgets/ipywidgets/utils.py deleted file mode 100644 index f4b86c3fc2..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/utils.py +++ /dev/null @@ -1,97 +0,0 @@ -import ipywidgets.widgets as widgets -import numpy as np - - -def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): - time_slider = widgets.FloatSlider( - orientation="horizontal", - description="time:", - value=time_range[0], - min=t_start, - max=t_stop, - continuous_update=False, - layout=widgets.Layout(width=f"{width_cm}cm"), - ) - layer_selector = widgets.Dropdown(description="layer", options=layer_keys) - segment_selector = widgets.Dropdown(description="segment", options=list(range(num_segments))) - window_sizer = widgets.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") - mode_selector = widgets.Dropdown(options=["line", "map"], description="mode", value=mode) - all_layers = widgets.Checkbox(description="plot all layers", value=all_layers) - - controller = { - "layer_key": layer_selector, - "segment_index": segment_selector, - "window": window_sizer, - "t_start": time_slider, - "mode": mode_selector, - "all_layers": all_layers, - } - widget = widgets.VBox( - [time_slider, widgets.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] - ) - - return widget, controller - - -def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): - unit_label = widgets.Label(value="units:") - - unit_selector = widgets.SelectMultiple( - options=all_unit_ids, - value=list(unit_ids), - disabled=False, - layout=widgets.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), - ) - - controller = {"unit_ids": unit_selector} - widget = widgets.VBox([unit_label, unit_selector]) - - return widget, controller - - -def make_channel_controller(recording, width_cm, height_cm): - channel_label = widgets.Label("channel indices:", layout=widgets.Layout(justify_content="center")) - channel_selector = widgets.IntRangeSlider( - value=[0, recording.get_num_channels()], - min=0, - max=recording.get_num_channels(), - step=1, - disabled=False, - continuous_update=False, - orientation="vertical", - readout=True, - readout_format="d", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), - ) - - controller = {"channel_inds": channel_selector} - widget = widgets.VBox([channel_label, channel_selector]) - - return widget, controller - - -def make_scale_controller(width_cm, height_cm): - scale_label = widgets.Label("Scale", layout=widgets.Layout(justify_content="center")) - - plus_selector = widgets.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Increase scale", - icon="arrow-up", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) - - minus_selector = widgets.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Decrease scale", - icon="arrow-down", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) - - controller = {"plus": plus_selector, "minus": minus_selector} - widget = widgets.VBox([scale_label, plus_selector, minus_selector]) - - return widget, controller diff --git a/src/spikeinterface/widgets/matplotlib/__init__.py b/src/spikeinterface/widgets/matplotlib/__init__.py deleted file mode 100644 index 525396e30d..0000000000 --- a/src/spikeinterface/widgets/matplotlib/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .all_amplitudes_distributions import AllAmplitudesDistributionsPlotter -from .amplitudes import AmplitudesPlotter -from .autocorrelograms import AutoCorrelogramsPlotter -from .crosscorrelograms import CrossCorrelogramsPlotter -from .quality_metrics import QualityMetricsPlotter -from .motion import MotionPlotter -from .spike_locations import SpikeLocationsPlotter -from .spikes_on_traces import SpikesOnTracesPlotter -from .template_metrics import TemplateMetricsPlotter -from .template_similarity import TemplateSimilarityPlotter -from .timeseries import TimeseriesPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesWidget -from .unit_waveforms import UnitWaveformPlotter -from .unit_waveforms_density_map import UnitWaveformDensityMapPlotter -from .unit_depths import UnitDepthsPlotter -from .unit_summary import UnitSummaryPlotter diff --git a/src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py b/src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py deleted file mode 100644 index 6985d2167a..0000000000 --- a/src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py +++ /dev/null @@ -1,41 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..all_amplitudes_distributions import AllAmplitudesDistributionsWidget -from .base_mpl import MplPlotter - - -class AllAmplitudesDistributionsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - self.make_mpl_figure(**backend_kwargs) - - ax = self.ax - - unit_amps = [] - for i, unit_id in enumerate(dp.unit_ids): - amps = [] - for segment_index in range(dp.num_segments): - amps.append(dp.amplitudes[segment_index][unit_id]) - amps = np.concatenate(amps) - unit_amps.append(amps) - parts = ax.violinplot(unit_amps, showmeans=False, showmedians=False, showextrema=False) - - for i, pc in enumerate(parts["bodies"]): - color = dp.unit_colors[dp.unit_ids[i]] - pc.set_facecolor(color) - pc.set_edgecolor("black") - pc.set_alpha(1) - - ax.set_xticks(np.arange(len(dp.unit_ids)) + 1) - ax.set_xticklabels([str(unit_id) for unit_id in dp.unit_ids]) - - ylims = ax.get_ylim() - if np.max(ylims) < 0: - ax.set_ylim(min(ylims), 0) - if np.min(ylims) > 0: - ax.set_ylim(0, max(ylims)) - - -AllAmplitudesDistributionsPlotter.register(AllAmplitudesDistributionsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/amplitudes.py b/src/spikeinterface/widgets/matplotlib/amplitudes.py deleted file mode 100644 index 747709211a..0000000000 --- a/src/spikeinterface/widgets/matplotlib/amplitudes.py +++ /dev/null @@ -1,69 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..amplitudes import AmplitudesWidget -from .base_mpl import MplPlotter - - -class AmplitudesPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - if backend_kwargs["axes"] is not None: - axes = backend_kwargs["axes"] - if dp.plot_histograms: - assert np.asarray(axes).size == 2 - else: - assert np.asarray(axes).size == 1 - elif backend_kwargs["ax"] is not None: - assert not dp.plot_histograms - else: - if dp.plot_histograms: - backend_kwargs["num_axes"] = 2 - backend_kwargs["ncols"] = 2 - else: - backend_kwargs["num_axes"] = None - - self.make_mpl_figure(**backend_kwargs) - - scatter_ax = self.axes.flatten()[0] - - for unit_id in dp.unit_ids: - spiketrains = dp.spiketrains[unit_id] - amps = dp.amplitudes[unit_id] - scatter_ax.scatter(spiketrains, amps, color=dp.unit_colors[unit_id], s=3, alpha=1, label=unit_id) - - if dp.plot_histograms: - if dp.bins is None: - bins = int(len(spiketrains) / 30) - else: - bins = dp.bins - ax_hist = self.axes.flatten()[1] - ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) - - if dp.plot_histograms: - ax_hist = self.axes.flatten()[1] - ax_hist.set_ylim(scatter_ax.get_ylim()) - ax_hist.axis("off") - self.figure.tight_layout() - - if dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - scatter_ax.set_xlim(0, dp.total_duration) - scatter_ax.set_xlabel("Times [s]") - scatter_ax.set_ylabel(f"Amplitude") - scatter_ax.spines["top"].set_visible(False) - scatter_ax.spines["right"].set_visible(False) - self.figure.subplots_adjust(bottom=0.1, top=0.9, left=0.1) - - -AmplitudesPlotter.register(AmplitudesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/autocorrelograms.py b/src/spikeinterface/widgets/matplotlib/autocorrelograms.py deleted file mode 100644 index 9245ef6881..0000000000 --- a/src/spikeinterface/widgets/matplotlib/autocorrelograms.py +++ /dev/null @@ -1,30 +0,0 @@ -from ..base import to_attr -from ..autocorrelograms import AutoCorrelogramsWidget -from .base_mpl import MplPlotter - - -class AutoCorrelogramsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = len(dp.unit_ids) - - self.make_mpl_figure(**backend_kwargs) - - bins = dp.bins - unit_ids = dp.unit_ids - correlograms = dp.correlograms - bin_width = bins[1] - bins[0] - - for i, unit_id in enumerate(unit_ids): - ccg = correlograms[i, i] - ax = self.axes.flatten()[i] - if dp.unit_colors is None: - color = "g" - else: - color = dp.unit_colors[unit_id] - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - ax.set_title(str(unit_id)) - - -AutoCorrelogramsPlotter.register(AutoCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/base_mpl.py b/src/spikeinterface/widgets/matplotlib/base_mpl.py deleted file mode 100644 index 266adc8782..0000000000 --- a/src/spikeinterface/widgets/matplotlib/base_mpl.py +++ /dev/null @@ -1,102 +0,0 @@ -from spikeinterface.widgets.base import BackendPlotter - -import matplotlib.pyplot as plt -import numpy as np - - -class MplPlotter(BackendPlotter): - backend = "matplotlib" - backend_kwargs_desc = { - "figure": "Matplotlib figure. When None, it is created. Default None", - "ax": "Single matplotlib axis. When None, it is created. Default None", - "axes": "Multiple matplotlib axes. When None, they is created. Default None", - "ncols": "Number of columns to create in subplots. Default 5", - "figsize": "Size of matplotlib figure. Default None", - "figtitle": "The figure title. Default None", - } - default_backend_kwargs = {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None} - - def make_mpl_figure(self, figure=None, ax=None, axes=None, ncols=None, num_axes=None, figsize=None, figtitle=None): - """ - figure/ax/axes : only one of then can be not None - """ - if figure is not None: - assert ax is None and axes is None, "figure/ax/axes : only one of then can be not None" - if num_axes is None: - ax = figure.add_subplot(111) - axes = np.array([[ax]]) - else: - assert ncols is not None - axes = [] - nrows = int(np.ceil(num_axes / ncols)) - axes = np.full((nrows, ncols), fill_value=None, dtype=object) - for i in range(num_axes): - ax = figure.add_subplot(nrows, ncols, i + 1) - r = i // ncols - c = i % ncols - axes[r, c] = ax - elif ax is not None: - assert figure is None and axes is None, "figure/ax/axes : only one of then can be not None" - figure = ax.get_figure() - axes = np.array([[ax]]) - elif axes is not None: - assert figure is None and ax is None, "figure/ax/axes : only one of then can be not None" - axes = np.asarray(axes) - figure = axes.flatten()[0].get_figure() - else: - # 'figure/ax/axes are all None - if num_axes is None: - # one fig with one ax - figure, ax = plt.subplots(figsize=figsize) - axes = np.array([[ax]]) - else: - if num_axes == 0: - # one figure without plots (diffred subplot creation with - figure = plt.figure(figsize=figsize) - ax = None - axes = None - elif num_axes == 1: - figure = plt.figure(figsize=figsize) - ax = figure.add_subplot(111) - axes = np.array([[ax]]) - else: - assert ncols is not None - if num_axes < ncols: - ncols = num_axes - nrows = int(np.ceil(num_axes / ncols)) - figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize) - ax = None - # remove extra axes - if ncols * nrows > num_axes: - for i, extra_ax in enumerate(axes.flatten()): - if i >= num_axes: - extra_ax.remove() - r = i // ncols - c = i % ncols - axes[r, c] = None - - self.figure = figure - self.ax = ax - # axes is always a 2D array of ax - self.axes = axes - - if figtitle is not None: - self.figure.suptitle(figtitle) - - -class to_attr(object): - def __init__(self, d): - """ - Helper function that transform a dict into - an object where attributes are the keys of the dict - - d = {'a': 1, 'b': 'yep'} - o = to_attr(d) - print(o.a, o.b) - """ - object.__init__(self) - object.__setattr__(self, "__d", d) - - def __getattribute__(self, k): - d = object.__getattribute__(self, "__d") - return d[k] diff --git a/src/spikeinterface/widgets/matplotlib/crosscorrelograms.py b/src/spikeinterface/widgets/matplotlib/crosscorrelograms.py deleted file mode 100644 index 24ecdcdffc..0000000000 --- a/src/spikeinterface/widgets/matplotlib/crosscorrelograms.py +++ /dev/null @@ -1,39 +0,0 @@ -from ..base import to_attr -from ..crosscorrelograms import CrossCorrelogramsWidget -from .base_mpl import MplPlotter - - -class CrossCorrelogramsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["ncols"] = len(dp.unit_ids) - backend_kwargs["num_axes"] = int(len(dp.unit_ids) ** 2) - - self.make_mpl_figure(**backend_kwargs) - assert self.axes.ndim == 2 - - bins = dp.bins - unit_ids = dp.unit_ids - correlograms = dp.correlograms - bin_width = bins[1] - bins[0] - - for i, unit_id1 in enumerate(unit_ids): - for j, unit_id2 in enumerate(unit_ids): - ccg = correlograms[i, j] - ax = self.axes[i, j] - if i == j: - if dp.unit_colors is None: - color = "g" - else: - color = dp.unit_colors[unit_id1] - else: - color = "k" - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - - for i, unit_id in enumerate(unit_ids): - self.axes[0, i].set_title(str(unit_id)) - self.axes[-1, i].set_xlabel("CCG (ms)") - - -CrossCorrelogramsPlotter.register(CrossCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/metrics.py b/src/spikeinterface/widgets/matplotlib/metrics.py deleted file mode 100644 index cec4c11644..0000000000 --- a/src/spikeinterface/widgets/matplotlib/metrics.py +++ /dev/null @@ -1,50 +0,0 @@ -import numpy as np - -from ..base import to_attr -from .base_mpl import MplPlotter - - -class MetricsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - metrics = dp.metrics - num_metrics = len(metrics.columns) - - if "figsize" not in backend_kwargs: - backend_kwargs["figsize"] = (2 * num_metrics, 2 * num_metrics) - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = num_metrics**2 - backend_kwargs["ncols"] = num_metrics - - all_unit_ids = metrics.index.values - - self.make_mpl_figure(**backend_kwargs) - assert self.axes.ndim == 2 - - if dp.unit_ids is None: - colors = ["gray"] * len(all_unit_ids) - else: - colors = [] - for unit in all_unit_ids: - color = "gray" if unit not in dp.unit_ids else dp.unit_colors[unit] - colors.append(color) - - self.patches = [] - for i, m1 in enumerate(metrics.columns): - for j, m2 in enumerate(metrics.columns): - if i == j: - self.axes[i, j].hist(metrics[m1], color="gray") - else: - p = self.axes[i, j].scatter(metrics[m1], metrics[m2], c=colors, s=3, marker="o") - self.patches.append(p) - if i == num_metrics - 1: - self.axes[i, j].set_xlabel(m2, fontsize=10) - if j == 0: - self.axes[i, j].set_ylabel(m1, fontsize=10) - self.axes[i, j].set_xticklabels([]) - self.axes[i, j].set_yticklabels([]) - self.axes[i, j].spines["top"].set_visible(False) - self.axes[i, j].spines["right"].set_visible(False) - - self.figure.subplots_adjust(top=0.8, wspace=0.2, hspace=0.2) diff --git a/src/spikeinterface/widgets/matplotlib/motion.py b/src/spikeinterface/widgets/matplotlib/motion.py deleted file mode 100644 index 8a89351c8a..0000000000 --- a/src/spikeinterface/widgets/matplotlib/motion.py +++ /dev/null @@ -1,129 +0,0 @@ -from ..base import to_attr -from ..motion import MotionWidget -from .base_mpl import MplPlotter - -import numpy as np -from matplotlib.colors import Normalize - - -class MotionPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks - - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - assert backend_kwargs["axes"] is None - assert backend_kwargs["ax"] is None - - self.make_mpl_figure(**backend_kwargs) - fig = self.figure - fig.clear() - - is_rigid = dp.motion.shape[1] == 1 - - gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.3) - ax0 = fig.add_subplot(gs[0, 0]) - ax1 = fig.add_subplot(gs[0, 1]) - ax2 = fig.add_subplot(gs[1, 0]) - if not is_rigid: - ax3 = fig.add_subplot(gs[1, 1]) - ax1.sharex(ax0) - ax1.sharey(ax0) - - if dp.motion_lim is None: - motion_lim = np.max(np.abs(dp.motion)) * 1.05 - else: - motion_lim = dp.motion_lim - - if dp.times is None: - temporal_bins_plot = dp.temporal_bins - x = dp.peaks["sample_index"] / dp.sampling_frequency - else: - # use real times and adjust temporal bins with t_start - temporal_bins_plot = dp.temporal_bins + dp.times[0] - x = dp.times[dp.peaks["sample_index"]] - - corrected_location = correct_motion_on_peaks( - dp.peaks, - dp.peak_locations, - dp.sampling_frequency, - dp.motion, - dp.temporal_bins, - dp.spatial_bins, - direction="y", - ) - - y = dp.peak_locations["y"] - y2 = corrected_location["y"] - if dp.scatter_decimate is not None: - x = x[:: dp.scatter_decimate] - y = y[:: dp.scatter_decimate] - y2 = y2[:: dp.scatter_decimate] - - if dp.color_amplitude: - amps = dp.peaks["amplitude"] - amps_abs = np.abs(amps) - q_95 = np.quantile(amps_abs, 0.95) - if dp.scatter_decimate is not None: - amps = amps[:: dp.scatter_decimate] - amps_abs = amps_abs[:: dp.scatter_decimate] - cmap = plt.get_cmap(dp.amplitude_cmap) - if dp.amplitude_clim is None: - amps = amps_abs - amps /= q_95 - c = cmap(amps) - else: - norm_function = Normalize(vmin=dp.amplitude_clim[0], vmax=dp.amplitude_clim[1], clip=True) - c = cmap(norm_function(amps)) - color_kwargs = dict( - color=None, - c=c, - alpha=dp.amplitude_alpha, - ) - else: - color_kwargs = dict(color="k", c=None, alpha=dp.amplitude_alpha) - - ax0.scatter(x, y, s=1, **color_kwargs) - if dp.depth_lim is not None: - ax0.set_ylim(*dp.depth_lim) - ax0.set_title("Peak depth") - ax0.set_xlabel("Times [s]") - ax0.set_ylabel("Depth [um]") - - ax1.scatter(x, y2, s=1, **color_kwargs) - ax1.set_xlabel("Times [s]") - ax1.set_ylabel("Depth [um]") - ax1.set_title("Corrected peak depth") - - ax2.plot(temporal_bins_plot, dp.motion, alpha=0.2, color="black") - ax2.plot(temporal_bins_plot, np.mean(dp.motion, axis=1), color="C0") - ax2.set_ylim(-motion_lim, motion_lim) - ax2.set_ylabel("Motion [um]") - ax2.set_title("Motion vectors") - axes = [ax0, ax1, ax2] - - if not is_rigid: - im = ax3.imshow( - dp.motion.T, - aspect="auto", - origin="lower", - extent=( - temporal_bins_plot[0], - temporal_bins_plot[-1], - dp.spatial_bins[0], - dp.spatial_bins[-1], - ), - ) - im.set_clim(-motion_lim, motion_lim) - cbar = fig.colorbar(im) - cbar.ax.set_xlabel("motion [um]") - ax3.set_xlabel("Times [s]") - ax3.set_ylabel("Depth [um]") - ax3.set_title("Motion vectors") - axes.append(ax3) - self.axes = np.array(axes) - - -MotionPlotter.register(MotionWidget) diff --git a/src/spikeinterface/widgets/matplotlib/quality_metrics.py b/src/spikeinterface/widgets/matplotlib/quality_metrics.py deleted file mode 100644 index 3fc368770b..0000000000 --- a/src/spikeinterface/widgets/matplotlib/quality_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..quality_metrics import QualityMetricsWidget -from .metrics import MetricsPlotter - - -class QualityMetricsPlotter(MetricsPlotter): - pass - - -QualityMetricsPlotter.register(QualityMetricsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/spike_locations.py b/src/spikeinterface/widgets/matplotlib/spike_locations.py deleted file mode 100644 index 5c74df3fc8..0000000000 --- a/src/spikeinterface/widgets/matplotlib/spike_locations.py +++ /dev/null @@ -1,96 +0,0 @@ -from probeinterface import ProbeGroup -from probeinterface.plotting import plot_probe - -import numpy as np - -from ..base import to_attr -from ..spike_locations import SpikeLocationsWidget, estimate_axis_lims -from .base_mpl import MplPlotter - -from matplotlib.patches import Ellipse -from matplotlib.lines import Line2D - - -class SpikeLocationsPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - - spike_locations = dp.spike_locations - - probegroup = ProbeGroup.from_dict(dp.probegroup_dict) - probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) - contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) - - for probe in probegroup.probes: - text_on_contact = None - if dp.with_channel_ids: - text_on_contact = dp.channel_ids - - poly_contact, poly_contour = plot_probe( - probe, - ax=self.ax, - contacts_colors="w", - contacts_kargs=contacts_kargs, - probe_shape_kwargs=probe_shape_kwargs, - text_on_contact=text_on_contact, - ) - poly_contact.set_zorder(2) - if poly_contour is not None: - poly_contour.set_zorder(1) - - self.ax.set_title("") - - if dp.plot_all_units: - unit_colors = {} - unit_ids = dp.all_unit_ids - for unit in dp.all_unit_ids: - if unit not in dp.unit_ids: - unit_colors[unit] = "gray" - else: - unit_colors[unit] = dp.unit_colors[unit] - else: - unit_ids = dp.unit_ids - unit_colors = dp.unit_colors - labels = dp.unit_ids - - for i, unit in enumerate(unit_ids): - locs = spike_locations[unit] - - zorder = 5 if unit in dp.unit_ids else 3 - self.ax.scatter(locs["x"], locs["y"], s=2, alpha=0.3, color=unit_colors[unit], zorder=zorder) - - handles = [ - Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) - for unit in dp.unit_ids - ] - if dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - # set proper axis limits - xlims, ylims = estimate_axis_lims(spike_locations) - - ax_xlims = list(self.ax.get_xlim()) - ax_ylims = list(self.ax.get_ylim()) - - ax_xlims[0] = xlims[0] if xlims[0] < ax_xlims[0] else ax_xlims[0] - ax_xlims[1] = xlims[1] if xlims[1] > ax_xlims[1] else ax_xlims[1] - ax_ylims[0] = ylims[0] if ylims[0] < ax_ylims[0] else ax_ylims[0] - ax_ylims[1] = ylims[1] if ylims[1] > ax_ylims[1] else ax_ylims[1] - - self.ax.set_xlim(ax_xlims) - self.ax.set_ylim(ax_ylims) - if dp.hide_axis: - self.ax.axis("off") - - -SpikeLocationsPlotter.register(SpikeLocationsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/spikes_on_traces.py b/src/spikeinterface/widgets/matplotlib/spikes_on_traces.py deleted file mode 100644 index d620c8f28f..0000000000 --- a/src/spikeinterface/widgets/matplotlib/spikes_on_traces.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..spikes_on_traces import SpikesOnTracesWidget -from .base_mpl import MplPlotter -from .timeseries import TimeseriesPlotter - -from matplotlib.patches import Ellipse -from matplotlib.lines import Line2D - - -class SpikesOnTracesPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - - # first plot time series - tsplotter = TimeseriesPlotter() - data_plot["timeseries"]["add_legend"] = False - tsplotter.do_plot(dp.timeseries, **backend_kwargs) - self.ax = tsplotter.ax - self.axes = tsplotter.axes - self.figure = tsplotter.figure - - ax = self.ax - - we = dp.waveform_extractor - sorting = dp.waveform_extractor.sorting - frame_range = dp.timeseries["frame_range"] - segment_index = dp.timeseries["segment_index"] - min_y = np.min(dp.timeseries["channel_locations"][:, 1]) - max_y = np.max(dp.timeseries["channel_locations"][:, 1]) - - n = len(dp.timeseries["channel_ids"]) - order = dp.timeseries["order"] - if order is None: - order = np.arange(n) - - if ax.get_legend() is not None: - ax.get_legend().remove() - - # loop through units and plot a scatter of spikes at estimated location - handles = [] - labels = [] - - for unit in dp.unit_ids: - spike_frames = sorting.get_unit_spike_train(unit, segment_index=segment_index) - spike_start, spike_end = np.searchsorted(spike_frames, frame_range) - - chan_ids = dp.sparsity.unit_id_to_channel_ids[unit] - - spike_frames_to_plot = spike_frames[spike_start:spike_end] - - if dp.timeseries["mode"] == "map": - spike_times_to_plot = sorting.get_unit_spike_train( - unit, segment_index=segment_index, return_times=True - )[spike_start:spike_end] - unit_y_loc = min_y + max_y - dp.unit_locations[unit][1] - # markers = np.ones_like(spike_frames_to_plot) * (min_y + max_y - dp.unit_locations[unit][1]) - width = 2 * 1e-3 - ellipse_kwargs = dict(width=width, height=10, fc="none", ec=dp.unit_colors[unit], lw=2) - patches = [Ellipse((s, unit_y_loc), **ellipse_kwargs) for s in spike_times_to_plot] - for p in patches: - ax.add_patch(p) - handles.append( - Line2D( - [0], - [0], - ls="", - marker="o", - markersize=5, - markeredgewidth=2, - markeredgecolor=dp.unit_colors[unit], - markerfacecolor="none", - ) - ) - labels.append(unit) - else: - # construct waveforms - label_set = False - if len(spike_frames_to_plot) > 0: - vspacing = dp.timeseries["vspacing"] - traces = dp.timeseries["list_traces"][0] - waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1) - - times = dp.timeseries["times"][waveform_idxs] - # discontinuity - times[:, -1] = np.nan - times_r = times.reshape(times.shape[0] * times.shape[1]) - waveforms = traces[waveform_idxs] # [:, :, order] - waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) - - for i, chan_id in enumerate(dp.timeseries["channel_ids"]): - offset = vspacing * i - if chan_id in chan_ids: - l = ax.plot(times_r, offset + waveforms_r[:, i], color=dp.unit_colors[unit]) - if not label_set: - handles.append(l[0]) - labels.append(unit) - label_set = True - ax.legend(handles, labels) - - -SpikesOnTracesPlotter.register(SpikesOnTracesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/template_metrics.py b/src/spikeinterface/widgets/matplotlib/template_metrics.py deleted file mode 100644 index 0aea8ae428..0000000000 --- a/src/spikeinterface/widgets/matplotlib/template_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..template_metrics import TemplateMetricsWidget -from .metrics import MetricsPlotter - - -class TemplateMetricsPlotter(MetricsPlotter): - pass - - -TemplateMetricsPlotter.register(TemplateMetricsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/template_similarity.py b/src/spikeinterface/widgets/matplotlib/template_similarity.py deleted file mode 100644 index 1e0a2e6fae..0000000000 --- a/src/spikeinterface/widgets/matplotlib/template_similarity.py +++ /dev/null @@ -1,30 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..template_similarity import TemplateSimilarityWidget -from .base_mpl import MplPlotter - - -class TemplateSimilarityPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - - im = self.ax.matshow(dp.similarity, cmap=dp.cmap) - - if dp.show_unit_ticks: - # Major ticks - self.ax.set_xticks(np.arange(0, len(dp.unit_ids))) - self.ax.set_yticks(np.arange(0, len(dp.unit_ids))) - self.ax.xaxis.tick_bottom() - - # Labels for major ticks - self.ax.set_yticklabels(dp.unit_ids, fontsize=12) - self.ax.set_xticklabels(dp.unit_ids, fontsize=12) - if dp.show_colorbar: - self.figure.colorbar(im) - - -TemplateSimilarityPlotter.register(TemplateSimilarityWidget) diff --git a/src/spikeinterface/widgets/matplotlib/timeseries.py b/src/spikeinterface/widgets/matplotlib/timeseries.py deleted file mode 100644 index 0a887b559f..0000000000 --- a/src/spikeinterface/widgets/matplotlib/timeseries.py +++ /dev/null @@ -1,70 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..timeseries import TimeseriesWidget -from .base_mpl import MplPlotter -from matplotlib.ticker import MaxNLocator - - -class TimeseriesPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - ax = self.ax - n = len(dp.channel_ids) - if dp.channel_locations is not None: - y_locs = dp.channel_locations[:, 1] - else: - y_locs = np.arange(n) - min_y = np.min(y_locs) - max_y = np.max(y_locs) - - if dp.mode == "line": - offset = dp.vspacing * (n - 1) - - for layer_key, traces in zip(dp.layer_keys, dp.list_traces): - for i, chan_id in enumerate(dp.channel_ids): - offset = dp.vspacing * i - color = dp.colors[layer_key][chan_id] - ax.plot(dp.times, offset + traces[:, i], color=color) - ax.get_lines()[-1].set_label(layer_key) - - if dp.show_channel_ids: - ax.set_yticks(np.arange(n) * dp.vspacing) - channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids]) - ax.set_yticklabels(channel_labels) - else: - ax.get_yaxis().set_visible(False) - - ax.set_xlim(*dp.time_range) - ax.set_ylim(-dp.vspacing, dp.vspacing * n) - ax.get_xaxis().set_major_locator(MaxNLocator(prune="both")) - ax.set_xlabel("time (s)") - if dp.add_legend: - ax.legend(loc="upper right") - - elif dp.mode == "map": - assert len(dp.list_traces) == 1, 'plot_timeseries with mode="map" do not support multi recording' - assert len(dp.clims) == 1 - clim = list(dp.clims.values())[0] - extent = (dp.time_range[0], dp.time_range[1], min_y, max_y) - im = ax.imshow( - dp.list_traces[0].T, interpolation="nearest", origin="lower", aspect="auto", extent=extent, cmap=dp.cmap - ) - - im.set_clim(*clim) - - if dp.with_colorbar: - self.figure.colorbar(im, ax=ax) - - if dp.show_channel_ids: - ax.set_yticks(np.linspace(min_y, max_y, n) + (max_y - min_y) / n * 0.5) - channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids]) - ax.set_yticklabels(channel_labels) - else: - ax.get_yaxis().set_visible(False) - - -TimeseriesPlotter.register(TimeseriesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_depths.py b/src/spikeinterface/widgets/matplotlib/unit_depths.py deleted file mode 100644 index aa16ff3578..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_depths.py +++ /dev/null @@ -1,22 +0,0 @@ -from ..base import to_attr -from ..unit_depths import UnitDepthsWidget -from .base_mpl import MplPlotter - - -class UnitDepthsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - self.make_mpl_figure(**backend_kwargs) - - ax = self.ax - size = dp.num_spikes / max(dp.num_spikes) * 120 - ax.scatter(dp.unit_amplitudes, dp.unit_depths, color=dp.colors, s=size) - - ax.set_aspect(3) - ax.set_xlabel("amplitude") - ax.set_ylabel("depth [um]") - ax.set_xlim(0, max(dp.unit_amplitudes) * 1.2) - - -UnitDepthsPlotter.register(UnitDepthsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_locations.py b/src/spikeinterface/widgets/matplotlib/unit_locations.py deleted file mode 100644 index 6f084c0aec..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_locations.py +++ /dev/null @@ -1,95 +0,0 @@ -from probeinterface import ProbeGroup -from probeinterface.plotting import plot_probe - -import numpy as np -from spikeinterface.core import waveform_extractor - -from ..base import to_attr -from ..unit_locations import UnitLocationsWidget -from .base_mpl import MplPlotter - -from matplotlib.patches import Ellipse -from matplotlib.lines import Line2D - - -class UnitLocationsPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - - unit_locations = dp.unit_locations - - probegroup = ProbeGroup.from_dict(dp.probegroup_dict) - probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) - contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) - - for probe in probegroup.probes: - text_on_contact = None - if dp.with_channel_ids: - text_on_contact = dp.channel_ids - - poly_contact, poly_contour = plot_probe( - probe, - ax=self.ax, - contacts_colors="w", - contacts_kargs=contacts_kargs, - probe_shape_kwargs=probe_shape_kwargs, - text_on_contact=text_on_contact, - ) - poly_contact.set_zorder(2) - if poly_contour is not None: - poly_contour.set_zorder(1) - - self.ax.set_title("") - - # color = np.array([dp.unit_colors[unit_id] for unit_id in dp.unit_ids]) - width = height = 10 - ellipse_kwargs = dict(width=width, height=height, lw=2) - - if dp.plot_all_units: - unit_colors = {} - unit_ids = dp.all_unit_ids - for unit in dp.all_unit_ids: - if unit not in dp.unit_ids: - unit_colors[unit] = "gray" - else: - unit_colors[unit] = dp.unit_colors[unit] - else: - unit_ids = dp.unit_ids - unit_colors = dp.unit_colors - labels = dp.unit_ids - - patches = [ - Ellipse( - (unit_locations[unit]), - color=unit_colors[unit], - zorder=5 if unit in dp.unit_ids else 3, - alpha=0.9 if unit in dp.unit_ids else 0.5, - **ellipse_kwargs, - ) - for i, unit in enumerate(unit_ids) - ] - for p in patches: - self.ax.add_patch(p) - handles = [ - Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) - for unit in dp.unit_ids - ] - - if dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - if dp.hide_axis: - self.ax.axis("off") - - -UnitLocationsPlotter.register(UnitLocationsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_summary.py b/src/spikeinterface/widgets/matplotlib/unit_summary.py deleted file mode 100644 index 5327afa25e..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_summary.py +++ /dev/null @@ -1,73 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..unit_summary import UnitSummaryWidget -from .base_mpl import MplPlotter - - -from .unit_locations import UnitLocationsPlotter -from .amplitudes import AmplitudesPlotter -from .unit_waveforms import UnitWaveformPlotter -from .unit_waveforms_density_map import UnitWaveformDensityMapPlotter - -from .autocorrelograms import AutoCorrelogramsPlotter - - -class UnitSummaryPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - - # force the figure without axes - if "figsize" not in backend_kwargs: - backend_kwargs["figsize"] = (18, 7) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = 0 - backend_kwargs["ax"] = None - backend_kwargs["axes"] = None - - self.make_mpl_figure(**backend_kwargs) - - # and use custum grid spec - fig = self.figure - nrows = 2 - ncols = 3 - if dp.plot_data_acc is not None or dp.plot_data_amplitudes is not None: - ncols += 1 - if dp.plot_data_amplitudes is not None: - nrows += 1 - gs = fig.add_gridspec(nrows, ncols) - - if dp.plot_data_unit_locations is not None: - ax1 = fig.add_subplot(gs[:2, 0]) - UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) - x, y = dp.unit_location[0], dp.unit_location[1] - ax1.set_xlim(x - 80, x + 80) - ax1.set_ylim(y - 250, y + 250) - ax1.set_xticks([]) - ax1.set_xlabel(None) - ax1.set_ylabel(None) - - ax2 = fig.add_subplot(gs[:2, 1]) - UnitWaveformPlotter().do_plot(dp.plot_data_waveforms, ax=ax2) - ax2.set_title(None) - - ax3 = fig.add_subplot(gs[:2, 2]) - UnitWaveformDensityMapPlotter().do_plot(dp.plot_data_waveform_density, ax=ax3) - ax3.set_ylabel(None) - - if dp.plot_data_acc is not None: - ax4 = fig.add_subplot(gs[:2, 3]) - AutoCorrelogramsPlotter().do_plot(dp.plot_data_acc, ax=ax4) - ax4.set_title(None) - ax4.set_yticks([]) - - if dp.plot_data_amplitudes is not None: - ax5 = fig.add_subplot(gs[2, :3]) - ax6 = fig.add_subplot(gs[2, 3]) - axes = np.array([ax5, ax6]) - AmplitudesPlotter().do_plot(dp.plot_data_amplitudes, axes=axes) - - fig.suptitle(f"unit_id: {dp.unit_id}") - - -UnitSummaryPlotter.register(UnitSummaryWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_templates.py b/src/spikeinterface/widgets/matplotlib/unit_templates.py deleted file mode 100644 index c1ce085bf2..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_templates.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..unit_templates import UnitTemplatesWidget -from .unit_waveforms import UnitWaveformPlotter - - -class UnitTemplatesPlotter(UnitWaveformPlotter): - pass - - -UnitTemplatesPlotter.register(UnitTemplatesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_waveforms.py b/src/spikeinterface/widgets/matplotlib/unit_waveforms.py deleted file mode 100644 index f499954918..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_waveforms.py +++ /dev/null @@ -1,95 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..unit_waveforms import UnitWaveformsWidget -from .base_mpl import MplPlotter - - -class UnitWaveformPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - if backend_kwargs["axes"] is not None: - assert len(backend_kwargs["axes"]) >= len(dp.unit_ids), "Provide as many 'axes' as neurons" - elif backend_kwargs["ax"] is not None: - assert dp.same_axis, "If 'same_axis' is not used, provide as many 'axes' as neurons" - else: - if dp.same_axis: - backend_kwargs["num_axes"] = 1 - backend_kwargs["ncols"] = None - else: - backend_kwargs["num_axes"] = len(dp.unit_ids) - backend_kwargs["ncols"] = min(dp.ncols, len(dp.unit_ids)) - - self.make_mpl_figure(**backend_kwargs) - - for i, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - else: - ax = self.axes.flatten()[i] - color = dp.unit_colors[unit_id] - - chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id] - xvectors_flat = dp.xvectors[:, chan_inds].T.flatten() - - # plot waveforms - if dp.plot_waveforms: - wfs = dp.wfs_by_ids[unit_id] - if dp.unit_selected_waveforms is not None: - wfs = wfs[dp.unit_selected_waveforms[unit_id]] - elif dp.max_spikes_per_unit is not None: - if len(wfs) > dp.max_spikes_per_unit: - random_idxs = np.random.permutation(len(wfs))[: dp.max_spikes_per_unit] - wfs = wfs[random_idxs] - wfs = wfs * dp.y_scale + dp.y_offset[None, :, chan_inds] - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T - - if dp.x_offset_units: - # 0.7 is to match spacing in xvect - xvec = xvectors_flat + i * 0.7 * dp.delta_x - else: - xvec = xvectors_flat - - ax.plot(xvec, wfs_flat, lw=dp.lw_waveforms, alpha=dp.alpha_waveforms, color=color) - - if not dp.plot_templates: - ax.get_lines()[-1].set_label(f"{unit_id}") - - # plot template - if dp.plot_templates: - template = dp.templates[i, :, :][:, chan_inds] * dp.y_scale + dp.y_offset[:, chan_inds] - - if dp.x_offset_units: - # 0.7 is to match spacing in xvect - xvec = xvectors_flat + i * 0.7 * dp.delta_x - else: - xvec = xvectors_flat - - ax.plot( - xvec, template.T.flatten(), lw=dp.lw_templates, alpha=dp.alpha_templates, color=color, label=unit_id - ) - - template_label = dp.unit_ids[i] - if dp.set_title: - ax.set_title(f"template {template_label}") - - # plot channels - if dp.plot_channels: - # TODO enhance this - ax.scatter(dp.channel_locations[:, 0], dp.channel_locations[:, 1], color="k") - - if dp.same_axis and dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - -UnitWaveformPlotter.register(UnitWaveformsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py b/src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py deleted file mode 100644 index ff9c1ec91b..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..unit_waveforms_density_map import UnitWaveformDensityMapWidget -from .base_mpl import MplPlotter - - -class UnitWaveformDensityMapPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None: - self.make_mpl_figure(**backend_kwargs) - else: - if dp.same_axis: - num_axes = 1 - else: - num_axes = len(dp.unit_ids) - backend_kwargs["ncols"] = 1 - backend_kwargs["num_axes"] = num_axes - self.make_mpl_figure(**backend_kwargs) - - if dp.same_axis: - ax = self.ax - hist2d = dp.all_hist2d - im = ax.imshow( - hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), - cmap="hot", - ) - else: - for unit_index, unit_id in enumerate(dp.unit_ids): - hist2d = dp.all_hist2d[unit_id] - ax = self.axes.flatten()[unit_index] - im = ax.imshow( - hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), - cmap="hot", - ) - - for unit_index, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - else: - ax = self.axes.flatten()[unit_index] - color = dp.unit_colors[unit_id] - ax.plot(dp.templates_flat[unit_id], color=color, lw=1) - - # final cosmetics - for unit_index, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - if unit_index != 0: - continue - else: - ax = self.axes.flatten()[unit_index] - chan_inds = dp.channel_inds[unit_id] - for i, chan_ind in enumerate(chan_inds): - if i != 0: - ax.axvline(i * dp.template_width, color="w", lw=3) - channel_id = dp.channel_ids[chan_ind] - x = i * dp.template_width + dp.template_width // 2 - y = (dp.bin_max + dp.bin_min) / 2.0 - ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") - - ax.set_xticks([]) - ax.set_ylabel(f"unit_id {unit_id}") - - -UnitWaveformDensityMapPlotter.register(UnitWaveformDensityMapWidget) diff --git a/src/spikeinterface/widgets/sortingview/__init__.py b/src/spikeinterface/widgets/sortingview/__init__.py deleted file mode 100644 index 5663f95078..0000000000 --- a/src/spikeinterface/widgets/sortingview/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .amplitudes import AmplitudesPlotter -from .autocorrelograms import AutoCorrelogramsPlotter -from .crosscorrelograms import CrossCorrelogramsPlotter -from .quality_metrics import QualityMetricsPlotter -from .sorting_summary import SortingSummaryPlotter -from .spike_locations import SortingviewPlotter -from .template_metrics import TemplateMetricsPlotter -from .template_similarity import TemplateSimilarityPlotter -from .timeseries import TimeseriesPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesPlotter diff --git a/src/spikeinterface/widgets/sortingview/amplitudes.py b/src/spikeinterface/widgets/sortingview/amplitudes.py deleted file mode 100644 index 8676ccd994..0000000000 --- a/src/spikeinterface/widgets/sortingview/amplitudes.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..amplitudes import AmplitudesWidget -from .base_sortingview import SortingviewPlotter - - -class AmplitudesPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Amplitudes" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - unit_ids = self.make_serializable(dp.unit_ids) - - sa_items = [ - vv.SpikeAmplitudesItem( - unit_id=u, - spike_times_sec=dp.spiketrains[u].astype("float32"), - spike_amplitudes=dp.amplitudes[u].astype("float32"), - ) - for u in unit_ids - ] - - v_spike_amplitudes = vv.SpikeAmplitudes( - start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector - ) - - self.handle_display_and_url(v_spike_amplitudes, **backend_kwargs) - return v_spike_amplitudes - - -AmplitudesPlotter.register(AmplitudesWidget) diff --git a/src/spikeinterface/widgets/sortingview/autocorrelograms.py b/src/spikeinterface/widgets/sortingview/autocorrelograms.py deleted file mode 100644 index 345f8c2bdf..0000000000 --- a/src/spikeinterface/widgets/sortingview/autocorrelograms.py +++ /dev/null @@ -1,34 +0,0 @@ -from ..base import to_attr -from ..autocorrelograms import AutoCorrelogramsWidget -from .base_sortingview import SortingviewPlotter - - -class AutoCorrelogramsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Auto Correlograms" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - unit_ids = self.make_serializable(dp.unit_ids) - - ac_items = [] - for i in range(len(unit_ids)): - for j in range(i, len(unit_ids)): - if i == j: - ac_items.append( - vv.AutocorrelogramItem( - unit_id=unit_ids[i], - bin_edges_sec=(dp.bins / 1000.0).astype("float32"), - bin_counts=dp.correlograms[i, j].astype("int32"), - ) - ) - - v_autocorrelograms = vv.Autocorrelograms(autocorrelograms=ac_items) - - self.handle_display_and_url(v_autocorrelograms, **backend_kwargs) - return v_autocorrelograms - - -AutoCorrelogramsPlotter.register(AutoCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/sortingview/base_sortingview.py b/src/spikeinterface/widgets/sortingview/base_sortingview.py deleted file mode 100644 index c42da0fba3..0000000000 --- a/src/spikeinterface/widgets/sortingview/base_sortingview.py +++ /dev/null @@ -1,103 +0,0 @@ -import numpy as np - -from ...core.core_tools import check_json -from spikeinterface.widgets.base import BackendPlotter - - -class SortingviewPlotter(BackendPlotter): - backend = "sortingview" - backend_kwargs_desc = { - "generate_url": "If True, the figurl URL is generated and printed. Default True", - "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", - "figlabel": "The figurl figure label. Default None", - "height": "The height of the sortingview View in jupyter. Default None", - } - default_backend_kwargs = {"generate_url": True, "display": True, "figlabel": None, "height": None} - - def __init__(self): - self.view = None - self.url = None - - def make_serializable(*args): - dict_to_serialize = {int(i): a for i, a in enumerate(args[1:])} - serializable_dict = check_json(dict_to_serialize) - returns = () - for i in range(len(args) - 1): - returns += (serializable_dict[str(i)],) - if len(returns) == 1: - returns = returns[0] - return returns - - @staticmethod - def is_notebook() -> bool: - try: - shell = get_ipython().__class__.__name__ - if shell == "ZMQInteractiveShell": - return True # Jupyter notebook or qtconsole - elif shell == "TerminalInteractiveShell": - return False # Terminal running IPython - else: - return False # Other type (?) - except NameError: - return False - - def handle_display_and_url(self, view, **backend_kwargs): - self.set_view(view) - if self.is_notebook() and backend_kwargs["display"]: - display(self.view.jupyter(height=backend_kwargs["height"])) - if backend_kwargs["generate_url"]: - figlabel = backend_kwargs.get("figlabel") - if figlabel is None: - figlabel = self.default_label - url = view.url(label=figlabel) - self.set_url(url) - print(url) - - # make view and url accessible by the plotter - def set_view(self, view): - self.view = view - - def set_url(self, url): - self.url = url - - -def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=None): - import sortingview.views as vv - - if unit_properties is None: - ut_columns = [] - ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] - else: - ut_columns = [] - ut_rows = [] - values = {} - valid_unit_properties = [] - for prop_name in unit_properties: - property_values = sorting.get_property(prop_name) - # make dtype available - val0 = np.array(property_values[0]) - if val0.dtype.kind in ("i", "u"): - dtype = "int" - elif val0.dtype.kind in ("U", "S"): - dtype = "str" - elif val0.dtype.kind == "f": - dtype = "float" - elif val0.dtype.kind == "b": - dtype = "bool" - else: - print(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") - continue - ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) - valid_unit_properties.append(prop_name) - - for ui, unit in enumerate(sorting.unit_ids): - for prop_name in valid_unit_properties: - property_values = sorting.get_property(prop_name) - val0 = property_values[0] - if np.isnan(property_values[ui]): - continue - values[prop_name] = property_values[ui] - ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) - - v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) - return v_units_table diff --git a/src/spikeinterface/widgets/sortingview/crosscorrelograms.py b/src/spikeinterface/widgets/sortingview/crosscorrelograms.py deleted file mode 100644 index ec9c7bb16c..0000000000 --- a/src/spikeinterface/widgets/sortingview/crosscorrelograms.py +++ /dev/null @@ -1,37 +0,0 @@ -from ..base import to_attr -from ..crosscorrelograms import CrossCorrelogramsWidget -from .base_sortingview import SortingviewPlotter - - -class CrossCorrelogramsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Cross Correlograms" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - unit_ids = self.make_serializable(dp.unit_ids) - - cc_items = [] - for i in range(len(unit_ids)): - for j in range(i, len(unit_ids)): - cc_items.append( - vv.CrossCorrelogramItem( - unit_id1=unit_ids[i], - unit_id2=unit_ids[j], - bin_edges_sec=(dp.bins / 1000.0).astype("float32"), - bin_counts=dp.correlograms[i, j].astype("int32"), - ) - ) - - v_cross_correlograms = vv.CrossCorrelograms( - cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector - ) - - self.handle_display_and_url(v_cross_correlograms, **backend_kwargs) - return v_cross_correlograms - - -CrossCorrelogramsPlotter.register(CrossCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/sortingview/metrics.py b/src/spikeinterface/widgets/sortingview/metrics.py deleted file mode 100644 index d46256739e..0000000000 --- a/src/spikeinterface/widgets/sortingview/metrics.py +++ /dev/null @@ -1,61 +0,0 @@ -import numpy as np - -from ...core.core_tools import check_json -from ..base import to_attr -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class MetricsPlotter(SortingviewPlotter): - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - metrics = dp.metrics - metric_names = list(metrics.columns) - - if dp.unit_ids is None: - unit_ids = metrics.index.values - else: - unit_ids = dp.unit_ids - unit_ids = self.make_serializable(unit_ids) - - metrics_sv = [] - for col in metric_names: - dtype = metrics.iloc[0][col].dtype - metric = vv.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str) - metrics_sv.append(metric) - - units_m = [] - for unit_id in unit_ids: - values = check_json(metrics.loc[unit_id].to_dict()) - values_skip_nans = {} - for k, v in values.items(): - if np.isnan(v): - continue - values_skip_nans[k] = v - - units_m.append(vv.UnitMetricsGraphUnit(unit_id=unit_id, values=values_skip_nans)) - v_metrics = vv.UnitMetricsGraph(units=units_m, metrics=metrics_sv) - - if not dp.hide_unit_selector: - if dp.include_metrics_data: - # make a view of the sorting to add tmp properties - sorting_copy = dp.sorting.select_units(unit_ids=dp.sorting.unit_ids) - for col in metric_names: - if col not in sorting_copy.get_property_keys(): - sorting_copy.set_property(col, metrics[col].values) - # generate table with properties - v_units_table = generate_unit_table_view(sorting_copy, unit_properties=metric_names) - else: - v_units_table = generate_unit_table_view(dp.sorting) - - view = vv.Splitter( - direction="horizontal", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_metrics) - ) - else: - view = v_metrics - - self.handle_display_and_url(view, **backend_kwargs) - return view diff --git a/src/spikeinterface/widgets/sortingview/quality_metrics.py b/src/spikeinterface/widgets/sortingview/quality_metrics.py deleted file mode 100644 index 379ba158a5..0000000000 --- a/src/spikeinterface/widgets/sortingview/quality_metrics.py +++ /dev/null @@ -1,11 +0,0 @@ -from .metrics import MetricsPlotter -from ..quality_metrics import QualityMetricsWidget - - -class QualityMetricsPlotter(MetricsPlotter): - default_label = "SpikeInterface - Quality Metrics" - - pass - - -QualityMetricsPlotter.register(QualityMetricsWidget) diff --git a/src/spikeinterface/widgets/sortingview/sorting_summary.py b/src/spikeinterface/widgets/sortingview/sorting_summary.py deleted file mode 100644 index bb248e1691..0000000000 --- a/src/spikeinterface/widgets/sortingview/sorting_summary.py +++ /dev/null @@ -1,86 +0,0 @@ -from ..base import to_attr -from ..sorting_summary import SortingSummaryWidget -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - -from .amplitudes import AmplitudesPlotter -from .autocorrelograms import AutoCorrelogramsPlotter -from .crosscorrelograms import CrossCorrelogramsPlotter -from .template_similarity import TemplateSimilarityPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesPlotter - - -class SortingSummaryPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Sorting Summary" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - dp = to_attr(data_plot) - - unit_ids = self.make_serializable(dp.unit_ids) - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - amplitudes_plotter = AmplitudesPlotter() - v_spike_amplitudes = amplitudes_plotter.do_plot( - dp.amplitudes, generate_url=False, display=False, backend="sortingview" - ) - template_plotter = UnitTemplatesPlotter() - v_average_waveforms = template_plotter.do_plot( - dp.templates, generate_url=False, display=False, backend="sortingview" - ) - xcorrelograms_plotter = CrossCorrelogramsPlotter() - v_cross_correlograms = xcorrelograms_plotter.do_plot( - dp.correlograms, generate_url=False, display=False, backend="sortingview" - ) - unitlocation_plotter = UnitLocationsPlotter() - v_unit_locations = unitlocation_plotter.do_plot( - dp.unit_locations, generate_url=False, display=False, backend="sortingview" - ) - # similarity - similarity_scores = [] - for i1, u1 in enumerate(unit_ids): - for i2, u2 in enumerate(unit_ids): - similarity_scores.append( - vv.UnitSimilarityScore( - unit_id1=u1, unit_id2=u2, similarity=dp.similarity["similarity"][i1, i2].astype("float32") - ) - ) - - # unit ids - v_units_table = generate_unit_table_view( - dp.waveform_extractor.sorting, dp.unit_table_properties, similarity_scores=similarity_scores - ) - - if dp.curation: - v_curation = vv.SortingCuration2(label_choices=dp.label_choices) - v1 = vv.Splitter(direction="vertical", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_curation)) - else: - v1 = v_units_table - v2 = vv.Splitter( - direction="horizontal", - item1=vv.LayoutItem(v_unit_locations, stretch=0.2), - item2=vv.LayoutItem( - vv.Splitter( - direction="horizontal", - item1=vv.LayoutItem(v_average_waveforms), - item2=vv.LayoutItem( - vv.Splitter( - direction="vertical", - item1=vv.LayoutItem(v_spike_amplitudes), - item2=vv.LayoutItem(v_cross_correlograms), - ) - ), - ) - ), - ) - - # assemble layout - v_summary = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) - - self.handle_display_and_url(v_summary, **backend_kwargs) - return v_summary - - -SortingSummaryPlotter.register(SortingSummaryWidget) diff --git a/src/spikeinterface/widgets/sortingview/spike_locations.py b/src/spikeinterface/widgets/sortingview/spike_locations.py deleted file mode 100644 index 747c3df4e7..0000000000 --- a/src/spikeinterface/widgets/sortingview/spike_locations.py +++ /dev/null @@ -1,64 +0,0 @@ -from ..base import to_attr -from ..spike_locations import SpikeLocationsWidget, estimate_axis_lims -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class SpikeLocationsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Spike Locations" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - spike_locations = dp.spike_locations - - # ensure serializable for sortingview - unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) - - locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} - xlims, ylims = estimate_axis_lims(spike_locations) - - unit_items = [] - for unit in unit_ids: - spike_times_sec = dp.sorting.get_unit_spike_train( - unit_id=unit, segment_index=dp.segment_index, return_times=True - ) - unit_items.append( - vv.SpikeLocationsItem( - unit_id=unit, - spike_times_sec=spike_times_sec.astype("float32"), - x_locations=spike_locations[unit]["x"].astype("float32"), - y_locations=spike_locations[unit]["y"].astype("float32"), - ) - ) - - v_spike_locations = vv.SpikeLocations( - units=unit_items, - hide_unit_selector=dp.hide_unit_selector, - x_range=xlims.astype("float32"), - y_range=ylims.astype("float32"), - channel_locations=locations, - disable_auto_rotate=True, - ) - - if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.sorting) - - view = vv.Box( - direction="horizontal", - items=[ - vv.LayoutItem(v_units_table, max_size=150), - vv.LayoutItem(v_spike_locations), - ], - ) - else: - view = v_spike_locations - - self.set_view(view) - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -SpikeLocationsPlotter.register(SpikeLocationsWidget) diff --git a/src/spikeinterface/widgets/sortingview/template_metrics.py b/src/spikeinterface/widgets/sortingview/template_metrics.py deleted file mode 100644 index 204bb8f377..0000000000 --- a/src/spikeinterface/widgets/sortingview/template_metrics.py +++ /dev/null @@ -1,11 +0,0 @@ -from .metrics import MetricsPlotter -from ..template_metrics import TemplateMetricsWidget - - -class TemplateMetricsPlotter(MetricsPlotter): - default_label = "SpikeInterface - Template Metrics" - - pass - - -TemplateMetricsPlotter.register(TemplateMetricsWidget) diff --git a/src/spikeinterface/widgets/sortingview/template_similarity.py b/src/spikeinterface/widgets/sortingview/template_similarity.py deleted file mode 100644 index e35b8c2e34..0000000000 --- a/src/spikeinterface/widgets/sortingview/template_similarity.py +++ /dev/null @@ -1,32 +0,0 @@ -from ..base import to_attr -from ..template_similarity import TemplateSimilarityWidget -from .base_sortingview import SortingviewPlotter - - -class TemplateSimilarityPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Template Similarity" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - # ensure serializable for sortingview - unit_ids = self.make_serializable(dp.unit_ids) - - # similarity - ss_items = [] - for i1, u1 in enumerate(unit_ids): - for i2, u2 in enumerate(unit_ids): - ss_items.append( - vv.UnitSimilarityScore(unit_id1=u1, unit_id2=u2, similarity=dp.similarity[i1, i2].astype("float32")) - ) - - view = vv.UnitSimilarityMatrix(unit_ids=list(unit_ids), similarity_scores=ss_items) - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -TemplateSimilarityPlotter.register(TemplateSimilarityWidget) diff --git a/src/spikeinterface/widgets/sortingview/timeseries.py b/src/spikeinterface/widgets/sortingview/timeseries.py deleted file mode 100644 index eec0e920e4..0000000000 --- a/src/spikeinterface/widgets/sortingview/timeseries.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import warnings - -from ..base import to_attr -from ..timeseries import TimeseriesWidget -from ..utils import array_to_image -from .base_sortingview import SortingviewPlotter - - -class TimeseriesPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Timeseries" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - try: - import pyvips - except ImportError: - raise ImportError("To use the timeseries in sorting view you need the pyvips package.") - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - assert dp.mode == "map", 'sortingview plot_timeseries is only mode="map"' - - if not dp.order_channel_by_depth: - warnings.warn( - "It is recommended to set 'order_channel_by_depth' to True " "when using the sortingview backend" - ) - - tiled_layers = [] - for layer_key, traces in zip(dp.layer_keys, dp.list_traces): - img = array_to_image( - traces, - clim=dp.clims[layer_key], - num_timepoints_per_row=dp.num_timepoints_per_row, - colormap=dp.cmap, - scalebar=True, - sampling_frequency=dp.recordings[layer_key].get_sampling_frequency(), - ) - - tiled_layers.append(vv.TiledImageLayer(layer_key, img)) - - view_ts = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) - - self.set_view(view_ts) - - # timeseries currently doesn't display on the jupyter backend - backend_kwargs["display"] = False - self.handle_display_and_url(view_ts, **backend_kwargs) - return view_ts - - -TimeseriesPlotter.register(TimeseriesWidget) diff --git a/src/spikeinterface/widgets/sortingview/unit_locations.py b/src/spikeinterface/widgets/sortingview/unit_locations.py deleted file mode 100644 index 368b45321f..0000000000 --- a/src/spikeinterface/widgets/sortingview/unit_locations.py +++ /dev/null @@ -1,44 +0,0 @@ -from ..base import to_attr -from ..unit_locations import UnitLocationsWidget -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class UnitLocationsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Unit Locations" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - # ensure serializable for sortingview - unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) - - locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} - - unit_items = [] - for unit_id in unit_ids: - unit_items.append( - vv.UnitLocationsItem( - unit_id=unit_id, x=float(dp.unit_locations[unit_id][0]), y=float(dp.unit_locations[unit_id][1]) - ) - ) - - v_unit_locations = vv.UnitLocations(units=unit_items, channel_locations=locations, disable_auto_rotate=True) - - if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.sorting) - - view = vv.Box( - direction="horizontal", - items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_unit_locations)], - ) - else: - view = v_unit_locations - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -UnitLocationsPlotter.register(UnitLocationsWidget) diff --git a/src/spikeinterface/widgets/sortingview/unit_templates.py b/src/spikeinterface/widgets/sortingview/unit_templates.py deleted file mode 100644 index 37595740fd..0000000000 --- a/src/spikeinterface/widgets/sortingview/unit_templates.py +++ /dev/null @@ -1,54 +0,0 @@ -from ..base import to_attr -from ..unit_templates import UnitTemplatesWidget -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class UnitTemplatesPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Unit Templates" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - # ensure serializable for sortingview - unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids - unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices - - unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) - - templates_dict = {} - for u_i, unit in enumerate(unit_ids): - templates_dict[unit] = {} - templates_dict[unit]["mean"] = dp.templates[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] - templates_dict[unit]["std"] = dp.template_stds[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] - - aw_items = [ - vv.AverageWaveformItem( - unit_id=u, - channel_ids=list(unit_id_to_channel_ids[u]), - waveform=t["mean"].astype("float32"), - waveform_std_dev=t["std"].astype("float32"), - ) - for u, t in templates_dict.items() - ] - - locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} - v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations) - - if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.waveform_extractor.sorting) - - view = vv.Box( - direction="horizontal", - items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_average_waveforms)], - ) - else: - view = v_average_waveforms - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -UnitTemplatesPlotter.register(UnitTemplatesWidget) From d2d5a9cdc016845c11dbdaa50e2a7e39a4275a62 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 10:55:24 +0200 Subject: [PATCH 22/31] some clean --- src/spikeinterface/widgets/__init__.py | 34 ------ .../widgets/all_amplitudes_distributions.py | 2 +- src/spikeinterface/widgets/amplitudes.py | 6 +- .../widgets/autocorrelograms.py | 4 +- src/spikeinterface/widgets/base.py | 58 +-------- .../widgets/crosscorrelograms.py | 4 +- src/spikeinterface/widgets/metrics.py | 6 +- src/spikeinterface/widgets/motion.py | 2 +- src/spikeinterface/widgets/sorting_summary.py | 2 +- src/spikeinterface/widgets/spike_locations.py | 6 +- .../widgets/spikes_on_traces.py | 4 +- .../widgets/template_similarity.py | 4 +- src/spikeinterface/widgets/timeseries.py | 6 +- src/spikeinterface/widgets/unit_depths.py | 2 +- src/spikeinterface/widgets/unit_locations.py | 6 +- src/spikeinterface/widgets/unit_summary.py | 2 +- src/spikeinterface/widgets/unit_templates.py | 2 +- src/spikeinterface/widgets/unit_waveforms.py | 4 +- .../widgets/unit_waveforms_density_map.py | 2 +- ...pywidgets_utils.py => utils_ipywidgets.py} | 0 ...atplotlib_utils.py => utils_matplotlib.py} | 0 ...tingview_utils.py => utils_sortingview.py} | 0 src/spikeinterface/widgets/widget_list.py | 113 ++++-------------- 23 files changed, 64 insertions(+), 205 deletions(-) rename src/spikeinterface/widgets/{ipywidgets_utils.py => utils_ipywidgets.py} (100%) rename src/spikeinterface/widgets/{matplotlib_utils.py => utils_matplotlib.py} (100%) rename src/spikeinterface/widgets/{sortingview_utils.py => utils_sortingview.py} (100%) diff --git a/src/spikeinterface/widgets/__init__.py b/src/spikeinterface/widgets/__init__.py index bb779ff7fb..d3066f51fa 100644 --- a/src/spikeinterface/widgets/__init__.py +++ b/src/spikeinterface/widgets/__init__.py @@ -1,37 +1,3 @@ -# check if backend are available -# try: -# import matplotlib - -# HAVE_MPL = True -# except: -# HAVE_MPL = False - -# try: -# import sortingview - -# HAVE_SV = True -# except: -# HAVE_SV = False - -# try: -# import ipywidgets - -# HAVE_IPYW = True -# except: -# HAVE_IPYW = False - - -# # theses import make the Widget.resgister() at import time -# if HAVE_MPL: -# import spikeinterface.widgets.matplotlib - -# if HAVE_SV: -# import spikeinterface.widgets.sortingview - -# if HAVE_IPYW: -# import spikeinterface.widgets.ipywidgets - -# when importing widget list backend are already registered from .widget_list import * # general functions diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index d3cca278c9..56aaa77804 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -50,7 +50,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from matplotlib.patches import Ellipse from matplotlib.lines import Line2D diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index a2a3ccff3b..2be71f7470 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -115,7 +115,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from probeinterface.plotting import plot_probe from matplotlib.patches import Ellipse @@ -182,7 +182,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() @@ -269,7 +269,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index e7b5014367..ecb015bee2 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -11,7 +11,7 @@ def __init__(self, *args, **kargs): def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -37,7 +37,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import make_serializable, handle_display_and_url + from .utils_sortingview import make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 7c52e1f993..eaa151ccd9 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -49,9 +49,6 @@ def set_default_plotter_backend(backend): class BaseWidget: - # this need to be reset in the subclass - possible_backends = None - def __init__( self, data_plot=None, @@ -79,6 +76,12 @@ def __init__( if immediate_plot: self.do_plot() + # subclass must define one method per supported backend: + # def plot_matplotlib(self, data_plot, **backend_kwargs): + # def plot_ipywidgets(self, data_plot, **backend_kwargs): + # def plot_sortingview(self, data_plot, **backend_kwargs): + + @classmethod def get_possible_backends(cls): return [k for k in default_backend_kwargs if hasattr(cls, f"plot_{k}")] @@ -91,25 +94,10 @@ def check_backend(self, backend): ) return backend - # def check_backend_kwargs(self, plotter, backend, **backend_kwargs): - # plotter_kwargs = plotter.default_backend_kwargs - # for k in backend_kwargs: - # if k not in plotter_kwargs: - # raise Exception( - # f"{k} is not a valid plot argument or backend keyword argument. " - # f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" - # ) - def do_plot(self): - # backend = self.check_backend(backend) - func = getattr(self, f"plot_{self.backend}") func(self.data_plot, **self.backend_kwargs) - # @classmethod - # def register_backend(cls, backend_plotter): - # cls.possible_backends[backend_plotter.backend] = backend_plotter - @staticmethod def check_extensions(waveform_extractor, extensions): if isinstance(extensions, str): @@ -127,27 +115,6 @@ def check_extensions(waveform_extractor, extensions): raise Exception(error_msg) -# class BackendPlotter: -# backend = "" - -# @classmethod -# def register(cls, widget_cls): -# widget_cls.register_backend(cls) - -# def update_backend_kwargs(self, **backend_kwargs): -# backend_kwargs_ = self.default_backend_kwargs.copy() -# backend_kwargs_.update(backend_kwargs) -# return backend_kwargs_ - - -# def copy_signature(source_fct): -# def copy(target_fct): -# target_fct.__signature__ = inspect.signature(source_fct) -# return target_fct - -# return copy - - class to_attr(object): def __init__(self, d): """ @@ -164,16 +131,3 @@ def __init__(self, d): def __getattribute__(self, k): d = object.__getattribute__(self, "__d") return d[k] - - -# def define_widget_function_from_class(widget_class, name): -# @copy_signature(widget_class) -# def widget_func(*args, **kwargs): -# W = widget_class(*args, **kwargs) -# W.do_plot(W.backend, **W.backend_kwargs) -# return W.plotter - -# widget_func.__doc__ = widget_class.__doc__ -# widget_func.__name__ = name - -# return widget_func diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 4b83e61b69..5635466a2d 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -68,7 +68,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -104,7 +104,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 6551bb067e..3d5e247b93 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -81,7 +81,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) metrics = dp.metrics @@ -132,7 +132,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() @@ -228,7 +228,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 1ebbb71743..6420fe8848 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -71,7 +71,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from matplotlib.colors import Normalize from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 5498df9a33..9b4279d94e 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -85,7 +85,7 @@ def __init__( def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) we = dp.waveform_extractor diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 06495409cf..62feff9372 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -107,7 +107,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from matplotlib.lines import Line2D from probeinterface import ProbeGroup @@ -195,7 +195,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() @@ -272,7 +272,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 0aeb923f38..74fc7f7501 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -163,7 +163,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from matplotlib.patches import Ellipse from matplotlib.lines import Line2D @@ -286,7 +286,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index a6e0356db1..f43a47db62 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -65,7 +65,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -89,7 +89,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/timeseries.py index 86e886babc..7165dec12a 100644 --- a/src/spikeinterface/widgets/timeseries.py +++ b/src/spikeinterface/widgets/timeseries.py @@ -218,7 +218,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -284,7 +284,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import ( + from .utils_ipywidgets import ( check_ipywidget_backend, make_timeseries_controller, make_channel_controller, @@ -506,7 +506,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url try: import pyvips diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index faf9198c0d..9bcafb53e4 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -59,7 +59,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 9e35f7b32c..b923374a07 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -84,7 +84,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from probeinterface.plotting import plot_probe from matplotlib.patches import Ellipse @@ -170,7 +170,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() @@ -242,7 +242,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 66f522e3ca..82e3e79fb9 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -106,7 +106,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 04b26e300f..7e9a1c21a8 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -11,7 +11,7 @@ def __init__(self, *args, **kargs): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 833f13881d..f82d276d92 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -167,7 +167,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from probeinterface.plotting import plot_probe from matplotlib.patches import Ellipse @@ -260,7 +260,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 9216373d87..3320a232c6 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -159,7 +159,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) diff --git a/src/spikeinterface/widgets/ipywidgets_utils.py b/src/spikeinterface/widgets/utils_ipywidgets.py similarity index 100% rename from src/spikeinterface/widgets/ipywidgets_utils.py rename to src/spikeinterface/widgets/utils_ipywidgets.py diff --git a/src/spikeinterface/widgets/matplotlib_utils.py b/src/spikeinterface/widgets/utils_matplotlib.py similarity index 100% rename from src/spikeinterface/widgets/matplotlib_utils.py rename to src/spikeinterface/widgets/utils_matplotlib.py diff --git a/src/spikeinterface/widgets/sortingview_utils.py b/src/spikeinterface/widgets/utils_sortingview.py similarity index 100% rename from src/spikeinterface/widgets/sortingview_utils.py rename to src/spikeinterface/widgets/utils_sortingview.py diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index a753c78d4a..eab0345d53 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -1,81 +1,44 @@ -# from .base import define_widget_function_from_class from .base import backend_kwargs_desc -# basics -from .timeseries import TimeseriesWidget - -# waveform -from .unit_waveforms import UnitWaveformsWidget -from .unit_templates import UnitTemplatesWidget -from .unit_waveforms_density_map import UnitWaveformDensityMapWidget - -# isi/ccg/acg +from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget +from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget - -# peak activity - -# drift/motion - -# spikes-traces -from .spikes_on_traces import SpikesOnTracesWidget - -# PC related - -# units on probe -from .unit_locations import UnitLocationsWidget -from .spike_locations import SpikeLocationsWidget - -# unit presence - - -# comparison related - -# correlogram comparison - -# amplitudes -from .amplitudes import AmplitudesWidget -from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget - -# metrics +from .motion import MotionWidget from .quality_metrics import QualityMetricsWidget +from .sorting_summary import SortingSummaryWidget +from .spike_locations import SpikeLocationsWidget +from .spikes_on_traces import SpikesOnTracesWidget from .template_metrics import TemplateMetricsWidget - - -# motion/drift -from .motion import MotionWidget - -# similarity from .template_similarity import TemplateSimilarityWidget - - +from .timeseries import TimeseriesWidget from .unit_depths import UnitDepthsWidget - -# summary +from .unit_locations import UnitLocationsWidget from .unit_summary import UnitSummaryWidget -from .sorting_summary import SortingSummaryWidget +from .unit_templates import UnitTemplatesWidget +from .unit_waveforms_density_map import UnitWaveformDensityMapWidget +from .unit_waveforms import UnitWaveformsWidget widget_list = [ - AmplitudesWidget, AllAmplitudesDistributionsWidget, + AmplitudesWidget, AutoCorrelogramsWidget, CrossCorrelogramsWidget, + MotionWidget, QualityMetricsWidget, + SortingSummaryWidget, SpikeLocationsWidget, SpikesOnTracesWidget, TemplateMetricsWidget, - MotionWidget, TemplateSimilarityWidget, TimeseriesWidget, + UnitDepthsWidget, UnitLocationsWidget, + UnitSummaryWidget, UnitTemplatesWidget, - UnitWaveformsWidget, UnitWaveformDensityMapWidget, - UnitDepthsWidget, - # summary - UnitSummaryWidget, - SortingSummaryWidget, + UnitWaveformsWidget, ] @@ -105,45 +68,21 @@ # make function for all widgets -# plot_amplitudes = define_widget_function_from_class(AmplitudesWidget, "plot_amplitudes") -# plot_all_amplitudes_distributions = define_widget_function_from_class( -# AllAmplitudesDistributionsWidget, "plot_all_amplitudes_distributions" -# ) -# plot_autocorrelograms = define_widget_function_from_class(AutoCorrelogramsWidget, "plot_autocorrelograms") -# plot_crosscorrelograms = define_widget_function_from_class(CrossCorrelogramsWidget, "plot_crosscorrelograms") -# plot_quality_metrics = define_widget_function_from_class(QualityMetricsWidget, "plot_quality_metrics") -# plot_spike_locations = define_widget_function_from_class(SpikeLocationsWidget, "plot_spike_locations") -# plot_spikes_on_traces = define_widget_function_from_class(SpikesOnTracesWidget, "plot_spikes_on_traces") -# plot_template_metrics = define_widget_function_from_class(TemplateMetricsWidget, "plot_template_metrics") -# plot_motion = define_widget_function_from_class(MotionWidget, "plot_motion") -# plot_template_similarity = define_widget_function_from_class(TemplateSimilarityWidget, "plot_template_similarity") -# plot_timeseries = define_widget_function_from_class(TimeseriesWidget, "plot_timeseries") -# plot_unit_locations = define_widget_function_from_class(UnitLocationsWidget, "plot_unit_locations") -# plot_unit_templates = define_widget_function_from_class(UnitTemplatesWidget, "plot_unit_templates") -# plot_unit_waveforms = define_widget_function_from_class(UnitWaveformsWidget, "plot_unit_waveforms") -# plot_unit_waveforms_density_map = define_widget_function_from_class( -# UnitWaveformDensityMapWidget, "plot_unit_waveforms_density_map" -# ) -# plot_unit_depths = define_widget_function_from_class(UnitDepthsWidget, "plot_unit_depths") -# plot_unit_summary = define_widget_function_from_class(UnitSummaryWidget, "plot_unit_summary") -# plot_sorting_summary = define_widget_function_from_class(SortingSummaryWidget, "plot_sorting_summary") - - -plot_amplitudes = AmplitudesWidget plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget -plot_unit_locations = UnitLocationsWidget +plot_amplitudes = AmplitudesWidget plot_autocorrelograms = AutoCorrelogramsWidget plot_crosscorrelograms = CrossCorrelogramsWidget +plot_motion = MotionWidget +plot_quality_metrics = QualityMetricsWidget +plot_sorting_summary = SortingSummaryWidget plot_spike_locations = SpikeLocationsWidget plot_spikes_on_traces = SpikesOnTracesWidget plot_template_metrics = TemplateMetricsWidget -plot_timeseries = TimeseriesWidget -plot_quality_metrics = QualityMetricsWidget -plot_motion = MotionWidget plot_template_similarity = TemplateSimilarityWidget -plot_unit_templates = UnitTemplatesWidget -plot_unit_waveforms = UnitWaveformsWidget -plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget +plot_timeseries = TimeseriesWidget plot_unit_depths = UnitDepthsWidget +plot_unit_locations = UnitLocationsWidget plot_unit_summary = UnitSummaryWidget -plot_sorting_summary = SortingSummaryWidget +plot_unit_templates = UnitTemplatesWidget +plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget +plot_unit_waveforms = UnitWaveformsWidget From 91064c4d30a185c24a33d9eeee2dbd681eab91f9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 11:16:57 +0200 Subject: [PATCH 23/31] More clean --- .../widgets/all_amplitudes_distributions.py | 5 +- src/spikeinterface/widgets/amplitudes.py | 26 +------ .../widgets/autocorrelograms.py | 11 +-- .../widgets/crosscorrelograms.py | 10 +-- src/spikeinterface/widgets/metrics.py | 20 +---- src/spikeinterface/widgets/motion.py | 9 --- src/spikeinterface/widgets/quality_metrics.py | 1 - src/spikeinterface/widgets/sorting_summary.py | 45 +----------- src/spikeinterface/widgets/spike_locations.py | 21 +----- .../widgets/spikes_on_traces.py | 73 ------------------- .../widgets/template_metrics.py | 2 - .../widgets/template_similarity.py | 9 +-- src/spikeinterface/widgets/timeseries.py | 26 +------ src/spikeinterface/widgets/unit_depths.py | 3 +- src/spikeinterface/widgets/unit_locations.py | 14 +--- src/spikeinterface/widgets/unit_summary.py | 60 --------------- src/spikeinterface/widgets/unit_templates.py | 8 +- src/spikeinterface/widgets/unit_waveforms.py | 23 ------ .../widgets/unit_waveforms_density_map.py | 5 -- .../widgets/utils_matplotlib.py | 8 -- .../widgets/utils_sortingview.py | 8 -- 21 files changed, 17 insertions(+), 370 deletions(-) diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index 56aaa77804..280662fd7a 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -21,8 +21,6 @@ class AllAmplitudesDistributionsWidget(BaseWidget): Dict of colors with key: unit, value: color, default None """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs ): @@ -56,8 +54,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.lines import Line2D dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) ax = self.ax diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 2be71f7470..7ef6e0ff61 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -35,8 +35,6 @@ class AmplitudesWidget(BaseWidget): True includes legend in plot, default True """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -116,13 +114,8 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure - from probeinterface.plotting import plot_probe - - from matplotlib.patches import Ellipse - from matplotlib.lines import Line2D dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) if backend_kwargs["axes"] is not None: axes = backend_kwargs["axes"] @@ -139,7 +132,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: backend_kwargs["num_axes"] = None - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) scatter_ax = self.axes.flatten()[0] @@ -164,7 +156,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure.tight_layout() if dp.plot_legend: - # if self.legend is not None: if hasattr(self, "legend") and self.legend is not None: self.legend.remove() self.legend = self.figure.legend( @@ -191,7 +182,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): cm = 1 / 2.54 we = data_plot["waveform_extractor"] - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -200,7 +190,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): with plt.ioff(): output = widgets.Output() with output: - # fig = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() @@ -220,15 +209,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.controller = {"plot_histograms": plot_histograms} self.controller.update(unit_controller) - # mpl_plotter = MplAmplitudesPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) for w in self.controller.values(): - # w.observe(self.updater) w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( - # center=fig.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer center=self.figure.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], @@ -236,15 +220,12 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) def _update_ipywidget(self, change): - # self.fig.clear() self.figure.clear() unit_ids = self.controller["unit_ids"].value @@ -261,7 +242,6 @@ def _update_ipywidget(self, change): backend_kwargs["axes"] = None backend_kwargs["ax"] = None - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) self.figure.canvas.draw() @@ -271,10 +251,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) sa_items = [ @@ -286,10 +264,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): for u in unit_ids ] - # v_spike_amplitudes = vv.SpikeAmplitudes( self.view = vv.SpikeAmplitudes( start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector ) - # self.handle_display_and_url(v_spike_amplitudes, **backend_kwargs) - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index ecb015bee2..e98abbed8f 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -4,7 +4,7 @@ class AutoCorrelogramsWidget(CrossCorrelogramsWidget): - # possible_backends = {} + # the doc is copied form CrossCorrelogramsWidget def __init__(self, *args, **kargs): CrossCorrelogramsWidget.__init__(self, *args, **kargs) @@ -14,12 +14,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) backend_kwargs["num_axes"] = len(dp.unit_ids) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) - bins = dp.bins unit_ids = dp.unit_ids correlograms = dp.correlograms @@ -39,9 +36,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) ac_items = [] @@ -58,9 +53,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.view = vv.Autocorrelograms(autocorrelograms=ac_items) - # self.handle_display_and_url(v_autocorrelograms, **backend_kwargs) - # return v_autocorrelograms - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) AutoCorrelogramsWidget.__doc__ = CrossCorrelogramsWidget.__doc__ diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 5635466a2d..3ec3fa11b6 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -27,8 +27,6 @@ class CrossCorrelogramsWidget(BaseWidget): If given, a dictionary with unit ids as keys and colors as values, default None """ - # possible_backends = {} - def __init__( self, waveform_or_sorting_extractor: Union[WaveformExtractor, BaseSorting], @@ -71,11 +69,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) backend_kwargs["ncols"] = len(dp.unit_ids) backend_kwargs["num_axes"] = int(len(dp.unit_ids) ** 2) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) assert self.axes.ndim == 2 @@ -106,10 +102,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) cc_items = [] @@ -126,6 +120,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.view = vv.CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector) - # self.handle_display_and_url(v_cross_correlograms, **backend_kwargs) - # return v_cross_correlograms - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 3d5e247b93..9dc51f522e 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -30,8 +30,6 @@ class MetricsBaseWidget(BaseWidget): If True, metrics data are included in unit table, by default True """ - # possible_backends = {} - def __init__( self, metrics, @@ -90,13 +88,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if "figsize" not in backend_kwargs: backend_kwargs["figsize"] = (2 * num_metrics, 2 * num_metrics) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) backend_kwargs["num_axes"] = num_metrics**2 backend_kwargs["ncols"] = num_metrics all_unit_ids = metrics.index.values - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) assert self.axes.ndim == 2 @@ -160,11 +156,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.controller = unit_controller - # mpl_plotter = MplMetricsPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) - # for w in self.controller.values(): - # w.observe(self.updater) for w in self.controller.values(): w.observe(self._update_ipywidget) @@ -175,11 +166,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) def _update_ipywidget(self, change): @@ -199,16 +188,13 @@ def _update_ipywidget(self, change): sizes.append(size) # here we do a trick: we just update colors - # if hasattr(self.mpl_plotter, "patches"): if hasattr(self, "patches"): - # for p in self.mpl_plotter.patches: for p in self.patches: p.set_color(colors) p.set_sizes(sizes) else: backend_kwargs = {} backend_kwargs["figure"] = self.figure - # self.mpl_plotter.do_plot(self.data_plot, **backend_kwargs) self.plot_matplotlib(self.data_plot, **backend_kwargs) if len(unit_ids) > 0: @@ -231,7 +217,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) metrics = dp.metrics metric_names = list(metrics.columns) @@ -240,7 +225,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): unit_ids = metrics.index.values else: unit_ids = dp.unit_ids - # unit_ids = self.make_serializable(unit_ids) unit_ids = make_serializable(unit_ids) metrics_sv = [] @@ -279,6 +263,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): else: self.view = v_metrics - # self.handle_display_and_url(view, **backend_kwargs) - # return view - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 6420fe8848..cb11bcce0c 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -1,11 +1,6 @@ import numpy as np -from warnings import warn from .base import BaseWidget, to_attr -from .utils import get_unit_colors - - -from ..core.template_tools import get_template_extremum_amplitude class MotionWidget(BaseWidget): @@ -36,8 +31,6 @@ class MotionWidget(BaseWidget): The alpha of the scatter points, default 0.5 """ - # possible_backends = {} - def __init__( self, motion_info, @@ -77,12 +70,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) assert backend_kwargs["axes"] is None assert backend_kwargs["ax"] is None - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) fig = self.figure fig.clear() diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index 46bcd6c07b..459a32e6f2 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -1,6 +1,5 @@ from .metrics import MetricsBaseWidget from ..core.waveform_extractor import WaveformExtractor -from ..qualitymetrics import compute_quality_metrics class QualityMetricsWidget(MetricsBaseWidget): diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 9b4279d94e..9291de2956 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -9,7 +9,7 @@ from .unit_templates import UnitTemplatesWidget -from ..core import WaveformExtractor, ChannelSparsity +from ..core import WaveformExtractor class SortingSummaryWidget(BaseWidget): @@ -55,26 +55,10 @@ def __init__( if unit_ids is None: unit_ids = sorting.get_unit_ids() - # use other widgets to generate data (except for similarity) - # template_plot_data = UnitTemplatesWidget( - # we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True - # ).plot_data - # ccg_plot_data = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data - # amps_plot_data = AmplitudesWidget( - # we, unit_ids=unit_ids, max_spikes_per_unit=max_amplitudes_per_unit, hide_unit_selector=True - # ).plot_data - # locs_plot_data = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data - # sim_plot_data = TemplateSimilarityWidget(we, unit_ids=unit_ids).plot_data - plot_data = dict( waveform_extractor=waveform_extractor, unit_ids=unit_ids, sparsity=sparsity, - # templates=template_plot_data, - # correlograms=ccg_plot_data, - # amplitudes=amps_plot_data, - # similarity=sim_plot_data, - # unit_locations=locs_plot_data, unit_table_properties=unit_table_properties, curation=curation, label_choices=label_choices, @@ -92,28 +76,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): unit_ids = dp.unit_ids sparsity = dp.sparsity - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - # amplitudes_plotter = AmplitudesPlotter() - # v_spike_amplitudes = amplitudes_plotter.do_plot( - # dp.amplitudes, generate_url=False, display=False, backend="sortingview" - # ) - # template_plotter = UnitTemplatesPlotter() - # v_average_waveforms = template_plotter.do_plot( - # dp.templates, generate_url=False, display=False, backend="sortingview" - # ) - # xcorrelograms_plotter = CrossCorrelogramsPlotter() - # v_cross_correlograms = xcorrelograms_plotter.do_plot( - # dp.correlograms, generate_url=False, display=False, backend="sortingview" - # ) - # unitlocation_plotter = UnitLocationsPlotter() - # v_unit_locations = unitlocation_plotter.do_plot( - # dp.unit_locations, generate_url=False, display=False, backend="sortingview" - # ) - v_spike_amplitudes = AmplitudesWidget( we, unit_ids=unit_ids, @@ -144,7 +108,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): we, unit_ids=unit_ids, immediate_plot=False, generate_url=False, display=False, backend="sortingview" ) similarity = w.data_plot["similarity"] - print(similarity.shape) # similarity similarity_scores = [] @@ -183,10 +146,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ) # assemble layout - # v_summary = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) self.view = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) - # self.handle_display_and_url(v_summary, **backend_kwargs) - # return v_summary - - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 62feff9372..9771b2c0e9 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -1,5 +1,4 @@ import numpy as np -from typing import Union from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -114,9 +113,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from probeinterface.plotting import plot_probe dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) spike_locations = dp.spike_locations @@ -168,7 +165,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): for unit in dp.unit_ids ] if dp.plot_legend: - # if self.legend is not None: if hasattr(self, "legend") and self.legend is not None: self.legend.remove() self.legend = self.figure.legend( @@ -203,7 +199,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): cm = 1 / 2.54 - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -226,12 +221,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.controller = unit_controller - # mpl_plotter = MplSpikeLocationsPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - # for w in self.controller.values(): - # w.observe(self.updater) - for w in self.controller.values(): w.observe(self._update_ipywidget) @@ -242,11 +231,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) def _update_ipywidget(self, change): @@ -274,12 +261,10 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) spike_locations = dp.spike_locations # ensure serializable for sortingview - # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} @@ -321,11 +306,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): else: self.view = v_spike_locations - # self.set_view(view) - - # self.handle_display_and_url(view, **backend_kwargs) - # return view - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) def estimate_axis_lims(spike_locations, quantile=0.02): diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 74fc7f7501..ab4e629a2e 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -60,8 +60,6 @@ class SpikesOnTracesWidget(BaseWidget): For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 """ - # possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -86,29 +84,8 @@ def __init__( **backend_kwargs, ): we = waveform_extractor - # recording: BaseRecording = we.recording sorting: BaseSorting = we.sorting - # ts_widget = TimeseriesWidget( - # recording, - # segment_index, - # channel_ids, - # order_channel_by_depth, - # time_range, - # mode, - # return_scaled, - # cmap, - # show_channel_ids, - # color_groups, - # color, - # clim, - # tile_size, - # seconds_per_row, - # with_colorbar, - # backend, - # **backend_kwargs, - # ) - if unit_ids is None: unit_ids = sorting.get_unit_ids() unit_ids = unit_ids @@ -150,7 +127,6 @@ def __init__( ) plot_data = dict( - # timeseries=ts_widget.plot_data, waveform_extractor=waveform_extractor, options=options, unit_ids=unit_ids, @@ -173,14 +149,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): recording = we.recording sorting = we.sorting - # first plot time series - # tsplotter = TimeseriesPlotter() - # data_plot["timeseries"]["add_legend"] = False - # tsplotter.do_plot(dp.timeseries, **backend_kwargs) - # self.ax = tsplotter.ax - # self.axes = tsplotter.axes - # self.figure = tsplotter.figure - # first plot time series ts_widget = TimeseriesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) self.ax = ts_widget.ax @@ -189,20 +157,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax = self.ax - # we = dp.waveform_extractor - # sorting = dp.waveform_extractor.sorting - # frame_range = dp.timeseries["frame_range"] - # segment_index = dp.timeseries["segment_index"] - # min_y = np.min(dp.timeseries["channel_locations"][:, 1]) - # max_y = np.max(dp.timeseries["channel_locations"][:, 1]) - frame_range = ts_widget.data_plot["frame_range"] segment_index = ts_widget.data_plot["segment_index"] min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1]) max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1]) - # n = len(dp.timeseries["channel_ids"]) - # order = dp.timeseries["order"] n = len(ts_widget.data_plot["channel_ids"]) order = ts_widget.data_plot["order"] @@ -224,7 +183,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): spike_frames_to_plot = spike_frames[spike_start:spike_end] - # if dp.timeseries["mode"] == "map": if dp.options["mode"] == "map": spike_times_to_plot = sorting.get_unit_spike_train( unit, segment_index=segment_index, return_times=True @@ -253,16 +211,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # construct waveforms label_set = False if len(spike_frames_to_plot) > 0: - # vspacing = dp.timeseries["vspacing"] - # traces = dp.timeseries["list_traces"][0] vspacing = ts_widget.data_plot["vspacing"] traces = ts_widget.data_plot["list_traces"][0] waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - # waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1) waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1) - # times = dp.timeseries["times"][waveform_idxs] times = ts_widget.data_plot["times"][waveform_idxs] # discontinuity @@ -271,7 +225,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): waveforms = traces[waveform_idxs] # [:, :, order] waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) - # for i, chan_id in enumerate(dp.timeseries["channel_ids"]): for i, chan_id in enumerate(ts_widget.data_plot["channel_ids"]): offset = vspacing * i if chan_id in chan_ids: @@ -296,7 +249,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): we = dp.waveform_extractor ratios = [0.2, 0.8] - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) backend_kwargs_ts = backend_kwargs.copy() backend_kwargs_ts["width_cm"] = ratios[1] * backend_kwargs_ts["width_cm"] @@ -305,46 +257,28 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] # plot timeseries - # tsplotter = TimeseriesPlotter() - # data_plot["timeseries"]["add_legend"] = False - # tsplotter.do_plot(data_plot["timeseries"], **backend_kwargs_ts) - - # ts_w = tsplotter.widget - # ts_updater = tsplotter.updater - ts_widget = TimeseriesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) self.ax = ts_widget.ax self.axes = ts_widget.axes self.figure = ts_widget.figure - # we = data_plot["waveform_extractor"] - unit_widget, unit_controller = make_unit_controller( data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm ) self.controller = dict() - # self.controller = ts_updater.controller self.controller.update(ts_widget.controller) self.controller.update(unit_controller) - # mpl_plotter = MplSpikesOnTracesPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, ts_updater, self.controller) - # for w in self.controller.values(): - # w.observe(self.updater) - for w in self.controller.values(): w.observe(self._update_ipywidget) self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0]) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) def _update_ipywidget(self, change): @@ -352,19 +286,12 @@ def _update_ipywidget(self, change): unit_ids = self.controller["unit_ids"].value - # update ts - # self.ts_updater.__call__(change) - - # update data plot - # data_plot = self.data_plot.copy() data_plot = self.next_data_plot - # data_plot["timeseries"] = self.ts_updater.next_data_plot data_plot["unit_ids"] = unit_ids backend_kwargs = {} backend_kwargs["ax"] = self.ax - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) self.figure.canvas.draw() diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index 7361757666..748babb57d 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -22,8 +22,6 @@ class TemplateMetricsWidget(MetricsBaseWidget): For sortingview backend, if True the unit selector is not displayed, default False """ - # possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index f43a47db62..69aad70b1f 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -1,5 +1,4 @@ import numpy as np -from typing import Union from .base import BaseWidget, to_attr from ..core.waveform_extractor import WaveformExtractor @@ -68,9 +67,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) im = self.ax.matshow(dp.similarity, cmap=dp.cmap) @@ -91,11 +88,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) # ensure serializable for sortingview - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) # similarity @@ -108,6 +103,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.view = vv.UnitSimilarityMatrix(unit_ids=list(unit_ids), similarity_scores=ss_items) - # self.handle_display_and_url(view, **backend_kwargs) - # return view - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/timeseries.py index 7165dec12a..9439694639 100644 --- a/src/spikeinterface/widgets/timeseries.py +++ b/src/spikeinterface/widgets/timeseries.py @@ -58,8 +58,6 @@ class TimeseriesWidget(BaseWidget): The output widget """ - # possible_backends = {} - def __init__( self, recording, @@ -221,9 +219,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) ax = self.ax @@ -302,7 +298,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): cm = 1 / 2.54 - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] ratios = [0.1, 0.8, 0.2] @@ -335,15 +330,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.controller.update(ch_controller) self.controller.update(scale_controller) - # mpl_plotter = MplTimeseriesPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - # for w in self.controller.values(): - # if isinstance(w, widgets.Button): - # w.on_click(self.updater) - # else: - # w.observe(self.updater) - self.recordings = data_plot["recordings"] self.return_scaled = data_plot["return_scaled"] self.list_traces = None @@ -371,7 +357,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: @@ -497,7 +482,7 @@ def _update_ipywidget(self, change): backend_kwargs = {} backend_kwargs["ax"] = self.ax - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.figure @@ -506,7 +491,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import handle_display_and_url try: import pyvips @@ -536,17 +521,12 @@ def plot_sortingview(self, data_plot, **backend_kwargs): tiled_layers.append(vv.TiledImageLayer(layer_key, img)) - # view_ts = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) self.view = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) - # self.set_view(view_ts) - # timeseries currently doesn't display on the jupyter backend backend_kwargs["display"] = False - # self.handle_display_and_url(view_ts, **backend_kwargs) - # return view_ts - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 9bcafb53e4..e48f274962 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -62,8 +62,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) ax = self.ax diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index b923374a07..f8ea042f84 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -121,7 +121,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.ax.set_title("") - # color = np.array([dp.unit_colors[unit_id] for unit_id in dp.unit_ids]) width = height = 10 ellipse_kwargs = dict(width=width, height=height, lw=2) @@ -178,8 +177,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): cm = 1 / 2.54 - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -198,12 +195,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.controller = unit_controller - # mpl_plotter = MplUnitLocationsPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - # for w in self.controller.values(): - # w.observe(self.updater) - for w in self.controller.values(): w.observe(self._update_ipywidget) @@ -234,7 +225,6 @@ def _update_ipywidget(self, change): backend_kwargs = {} backend_kwargs["ax"] = self.ax - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() fig.canvas.draw() @@ -244,7 +234,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) # ensure serializable for sortingview @@ -272,5 +261,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): else: self.view = v_unit_locations - # self.handle_display_and_url(view, **backend_kwargs) - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 82e3e79fb9..964b5813e6 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -1,5 +1,4 @@ import numpy as np -from typing import Union from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -48,58 +47,11 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(we.sorting) - # if we.is_extension("unit_locations"): - # plot_data_unit_locations = UnitLocationsWidget( - # we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False - # ).plot_data - # unit_locations = waveform_extractor.load_extension("unit_locations").get_data(outputs="by_unit") - # unit_location = unit_locations[unit_id] - # else: - # plot_data_unit_locations = None - # unit_location = None - - # plot_data_waveforms = UnitWaveformsWidget( - # we, - # unit_ids=[unit_id], - # unit_colors=unit_colors, - # plot_templates=True, - # same_axis=True, - # plot_legend=False, - # sparsity=sparsity, - # ).plot_data - - # plot_data_waveform_density = UnitWaveformDensityMapWidget( - # we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, plot_templates=True, same_axis=False - # ).plot_data - - # if we.is_extension("correlograms"): - # plot_data_acc = AutoCorrelogramsWidget( - # we, - # unit_ids=[unit_id], - # unit_colors=unit_colors, - # ).plot_data - # else: - # plot_data_acc = None - - # use other widget to plot data - # if we.is_extension("spike_amplitudes"): - # plot_data_amplitudes = AmplitudesWidget( - # we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True - # ).plot_data - # else: - # plot_data_amplitudes = None - plot_data = dict( we=we, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, - # unit_location=unit_location, - # plot_data_unit_locations=plot_data_unit_locations, - # plot_data_waveforms=plot_data_waveforms, - # plot_data_waveform_density=plot_data_waveform_density, - # plot_data_acc=plot_data_acc, - # plot_data_amplitudes=plot_data_amplitudes, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -118,27 +70,22 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # force the figure without axes if "figsize" not in backend_kwargs: backend_kwargs["figsize"] = (18, 7) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) backend_kwargs["num_axes"] = 0 backend_kwargs["ax"] = None backend_kwargs["axes"] = None - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) # and use custum grid spec fig = self.figure nrows = 2 ncols = 3 - # if dp.plot_data_acc is not None or dp.plot_data_amplitudes is not None: if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"): ncols += 1 - # if dp.plot_data_amplitudes is not None : if we.is_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) - # if dp.plot_data_unit_locations is not None: if we.is_extension("unit_locations"): ax1 = fig.add_subplot(gs[:2, 0]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) @@ -148,7 +95,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_locations = we.load_extension("unit_locations").get_data(outputs="by_unit") unit_location = unit_locations[unit_id] - # x, y = dp.unit_location[0], dp.unit_location[1] x, y = unit_location[0], unit_location[1] ax1.set_xlim(x - 80, x + 80) ax1.set_ylim(y - 250, y + 250) @@ -157,7 +103,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1.set_ylabel(None) ax2 = fig.add_subplot(gs[:2, 1]) - # UnitWaveformPlotter().do_plot(dp.plot_data_waveforms, ax=ax2) w = UnitWaveformsWidget( we, unit_ids=[unit_id], @@ -173,7 +118,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax2.set_title(None) ax3 = fig.add_subplot(gs[:2, 2]) - # UnitWaveformDensityMapPlotter().do_plot(dp.plot_data_waveform_density, ax=ax3) UnitWaveformDensityMapWidget( we, unit_ids=[unit_id], @@ -185,10 +129,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) ax3.set_ylabel(None) - # if dp.plot_data_acc is not None: if we.is_extension("correlograms"): ax4 = fig.add_subplot(gs[:2, 3]) - # AutoCorrelogramsPlotter().do_plot(dp.plot_data_acc, ax=ax4) AutoCorrelogramsWidget( we, unit_ids=[unit_id], @@ -200,12 +142,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax4.set_title(None) ax4.set_yticks([]) - # if dp.plot_data_amplitudes is not None: if we.is_extension("spike_amplitudes"): ax5 = fig.add_subplot(gs[2, :3]) ax6 = fig.add_subplot(gs[2, 3]) axes = np.array([ax5, ax6]) - # AmplitudesPlotter().do_plot(dp.plot_data_amplitudes, axes=axes) AmplitudesWidget( we, unit_ids=[unit_id], diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 7e9a1c21a8..cf58e91aa0 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -3,7 +3,7 @@ class UnitTemplatesWidget(UnitWaveformsWidget): - # possible_backends = {} + # doc is copied from UnitWaveformsWidget def __init__(self, *args, **kargs): kargs["plot_waveforms"] = False @@ -14,13 +14,11 @@ def plot_sortingview(self, data_plot, **backend_kwargs): from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) # ensure serializable for sortingview unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices - # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) templates_dict = {} @@ -52,9 +50,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): else: self.view = v_average_waveforms - # self.handle_display_and_url(view, **backend_kwargs) - # return view - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) UnitTemplatesWidget.__doc__ = UnitWaveformsWidget.__doc__ diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index f82d276d92..e64765b44b 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -59,8 +59,6 @@ class UnitWaveformsWidget(BaseWidget): Display legend, default True """ - # possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -168,15 +166,9 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure - from probeinterface.plotting import plot_probe - - from matplotlib.patches import Ellipse - from matplotlib.lines import Line2D dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - if backend_kwargs.get("axes", None) is not None: assert len(backend_kwargs["axes"]) >= len(dp.unit_ids), "Provide as many 'axes' as neurons" elif backend_kwargs.get("ax", None) is not None: @@ -189,7 +181,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): backend_kwargs["num_axes"] = len(dp.unit_ids) backend_kwargs["ncols"] = min(dp.ncols, len(dp.unit_ids)) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) for i, unit_id in enumerate(dp.unit_ids): @@ -249,7 +240,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.scatter(dp.channel_locations[:, 0], dp.channel_locations[:, 1], color="k") if dp.same_axis and dp.plot_legend: - # if self.legend is not None: if hasattr(self, "legend") and self.legend is not None: self.legend.remove() self.legend = self.figure.legend( @@ -269,7 +259,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): cm = 1 / 2.54 self.we = we = data_plot["waveform_extractor"] - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -317,12 +306,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): } self.controller.update(unit_controller) - # mpl_plotter = MplUnitWaveformPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, fig_wf, ax_probe, self.controller) - # for w in self.controller.values(): - # w.observe(self.updater) - for w in self.controller.values(): w.observe(self._update_ipywidget) @@ -335,11 +318,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) def _update_ipywidget(self, change): @@ -369,18 +350,14 @@ def _update_ipywidget(self, change): else: backend_kwargs["figure"] = self.fig_wf - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) if same_axis: - # self.mpl_plotter.ax.axis("equal") self.ax.axis("equal") if hide_axis: - # self.mpl_plotter.ax.axis("off") self.ax.axis("off") else: if hide_axis: for i in range(len(unit_ids)): - # ax = self.mpl_plotter.axes.flatten()[i] ax = self.axes.flatten()[i] ax.axis("off") diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 3320a232c6..e8a6868e92 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -33,8 +33,6 @@ class UnitWaveformDensityMapWidget(BaseWidget): all channel per units, default False """ - # possible_backends = {} - def __init__( self, waveform_extractor, @@ -162,10 +160,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None: - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) else: if dp.same_axis: @@ -174,7 +170,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): num_axes = len(dp.unit_ids) backend_kwargs["ncols"] = 1 backend_kwargs["num_axes"] = num_axes - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) if dp.same_axis: diff --git a/src/spikeinterface/widgets/utils_matplotlib.py b/src/spikeinterface/widgets/utils_matplotlib.py index fb347552b1..a9128d7b66 100644 --- a/src/spikeinterface/widgets/utils_matplotlib.py +++ b/src/spikeinterface/widgets/utils_matplotlib.py @@ -65,11 +65,3 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, figure.suptitle(figtitle) return figure, axes, ax - - # self.figure = figure - # self.ax = ax - # axes is always a 2D array of ax - # self.axes = axes - - # if figtitle is not None: - # self.figure.suptitle(figtitle) diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 764246becf..24ae481a6b 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -3,14 +3,6 @@ from ..core.core_tools import check_json -sortingview_backend_kwargs_desc = { - "generate_url": "If True, the figurl URL is generated and printed. Default True", - "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", - "figlabel": "The figurl figure label. Default None", - "height": "The height of the sortingview View in jupyter. Default None", -} -sortingview_default_backend_kwargs = {"generate_url": True, "display": True, "figlabel": None, "height": None} - def make_serializable(*args): dict_to_serialize = {int(i): a for i, a in enumerate(args)} From 69af6b41d37b206adab8cdb5e8f198c5f2f0f9ab Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 11:25:51 +0200 Subject: [PATCH 24/31] plot_timeseries > plot_traces --- doc/api.rst | 2 +- doc/how_to/analyse_neuropixels.rst | 10 +++++----- doc/how_to/get_started.rst | 2 +- doc/modules/widgets.rst | 8 ++++---- examples/how_to/analyse_neuropixels.py | 10 +++++----- examples/how_to/get_started.py | 2 +- .../extractors/plot_1_read_various_formats.py | 2 +- .../widgets/plot_1_rec_gallery.py | 10 +++++----- .../extractors/tests/test_cbin_ibl_extractors.py | 2 +- .../preprocessing/tests/test_filter.py | 8 ++++---- .../preprocessing/tests/test_normalize_scale.py | 2 +- .../preprocessing/tests/test_phase_shift.py | 6 +++--- .../preprocessing/tests/test_rectify.py | 2 +- .../benchmark/benchmark_peak_localization.py | 2 +- .../widgets/_legacy_mpl_widgets/__init__.py | 2 +- .../widgets/_legacy_mpl_widgets/amplitudes.py | 6 +++--- .../widgets/_legacy_mpl_widgets/timeseries_.py | 8 ++++---- src/spikeinterface/widgets/spikes_on_traces.py | 6 +++--- src/spikeinterface/widgets/tests/test_widgets.py | 16 ++++++++-------- .../widgets/{timeseries.py => traces.py} | 10 +++++----- src/spikeinterface/widgets/widget_list.py | 13 ++++++++++--- 21 files changed, 68 insertions(+), 61 deletions(-) rename src/spikeinterface/widgets/{timeseries.py => traces.py} (98%) diff --git a/doc/api.rst b/doc/api.rst index e0a863bd9c..932c989c19 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -275,7 +275,7 @@ spikeinterface.widgets .. autofunction:: plot_spikes_on_traces .. autofunction:: plot_template_metrics .. autofunction:: plot_template_similarity - .. autofunction:: plot_timeseries + .. autofunction:: plot_traces .. autofunction:: plot_unit_depths .. autofunction:: plot_unit_locations .. autofunction:: plot_unit_summary diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyse_neuropixels.rst index 0a02a47211..31dbc7422c 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyse_neuropixels.rst @@ -264,7 +264,7 @@ the ipywydgets interactive ploter .. code:: python %matplotlib widget - si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') + si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') Note that using this ipywydgets make possible to explore diffrents preprocessing chain wihtout to save the entire file to disk. Everything @@ -276,9 +276,9 @@ is lazy, so you can change the previsous cell (parameters, step order, # here we use static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) - si.plot_timeseries(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) - si.plot_timeseries(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) - si.plot_timeseries(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) + si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) + si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) + si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) for i, label in enumerate(('filter', 'cmr', 'final')): axs[i].set_title(label) @@ -292,7 +292,7 @@ is lazy, so you can change the previsous cell (parameters, step order, # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) some_chans = rec.channel_ids[[100, 150, 200, ]] - si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) + si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) diff --git a/doc/how_to/get_started.rst b/doc/how_to/get_started.rst index 02ccb872d1..0f6aa9eb3f 100644 --- a/doc/how_to/get_started.rst +++ b/doc/how_to/get_started.rst @@ -104,7 +104,7 @@ and the raster plots. .. code:: ipython3 - w_ts = sw.plot_timeseries(recording, time_range=(0, 5)) + w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5)) diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index 9cb99ab5a1..86c541dfd0 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -123,7 +123,7 @@ The :code:`plot_*(..., backend="matplotlib")` functions come with the following .. code-block:: python # matplotlib backend - w = plot_timeseries(recording, backend="matplotlib") + w = plot_traces(recording, backend="matplotlib") **Output:** @@ -146,9 +146,9 @@ Each function has the following additional arguments: from spikeinterface.preprocessing import common_reference - # ipywidgets backend also supports multiple "layers" for plot_timeseries + # ipywidgets backend also supports multiple "layers" for plot_traces rec_dict = dict(filt=recording, cmr=common_reference(recording)) - w = sw.plot_timeseries(rec_dict, backend="ipywidgets") + w = sw.plot_traces(rec_dict, backend="ipywidgets") **Output:** @@ -171,7 +171,7 @@ The functions have the following additional arguments: .. code-block:: python # sortingview backend - w_ts = sw.plot_timeseries(recording, backend="ipywidgets") + w_ts = sw.plot_traces(recording, backend="ipywidgets") w_ss = sw.plot_sorting_summary(recording, backend="sortingview") diff --git a/examples/how_to/analyse_neuropixels.py b/examples/how_to/analyse_neuropixels.py index 9b9048cd0d..637120a591 100644 --- a/examples/how_to/analyse_neuropixels.py +++ b/examples/how_to/analyse_neuropixels.py @@ -82,7 +82,7 @@ # # ```python # # %matplotlib widget -# si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') +# si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') # ``` # # Note that using this ipywidgets make possible to explore different preprocessing chains without saving the entire file to disk. @@ -94,9 +94,9 @@ # here we use a static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) -si.plot_timeseries(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) -si.plot_timeseries(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) -si.plot_timeseries(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) +si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) +si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) +si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) for i, label in enumerate(('filter', 'cmr', 'final')): axs[i].set_title(label) # - @@ -104,7 +104,7 @@ # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) some_chans = rec.channel_ids[[100, 150, 200, ]] -si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) +si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) # ### Should we save the preprocessed data to a binary file? diff --git a/examples/how_to/get_started.py b/examples/how_to/get_started.py index 266d585de9..7860c605af 100644 --- a/examples/how_to/get_started.py +++ b/examples/how_to/get_started.py @@ -92,7 +92,7 @@ # # Let's use the `spikeinterface.widgets` module to visualize the traces and the raster plots. -w_ts = sw.plot_timeseries(recording, time_range=(0, 5)) +w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5)) # This is how you retrieve info from a `BaseRecording`... diff --git a/examples/modules_gallery/extractors/plot_1_read_various_formats.py b/examples/modules_gallery/extractors/plot_1_read_various_formats.py index 98988a1746..ed0ba34396 100644 --- a/examples/modules_gallery/extractors/plot_1_read_various_formats.py +++ b/examples/modules_gallery/extractors/plot_1_read_various_formats.py @@ -87,7 +87,7 @@ import spikeinterface.widgets as sw -w_ts = sw.plot_timeseries(recording, time_range=(0, 5)) +w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting, time_range=(0, 5)) plt.show() diff --git a/examples/modules_gallery/widgets/plot_1_rec_gallery.py b/examples/modules_gallery/widgets/plot_1_rec_gallery.py index d3d4792535..1544bbfc54 100644 --- a/examples/modules_gallery/widgets/plot_1_rec_gallery.py +++ b/examples/modules_gallery/widgets/plot_1_rec_gallery.py @@ -15,22 +15,22 @@ recording, sorting = se.toy_example(duration=10, num_channels=4, seed=0, num_segments=1) ############################################################################## -# plot_timeseries() +# plot_traces() # ~~~~~~~~~~~~~~~~~ -w_ts = sw.plot_timeseries(recording) +w_ts = sw.plot_traces(recording) ############################################################################## # We can select time range -w_ts1 = sw.plot_timeseries(recording, time_range=(5, 8)) +w_ts1 = sw.plot_traces(recording, time_range=(5, 8)) ############################################################################## # We can color with groups recording2 = recording.clone() recording2.set_channel_groups(channel_ids=recording.get_channel_ids(), groups=[0, 0, 1, 1]) -w_ts2 = sw.plot_timeseries(recording2, time_range=(5, 8), color_groups=True) +w_ts2 = sw.plot_traces(recording2, time_range=(5, 8), color_groups=True) ############################################################################## # **Note**: each function returns a widget object, which allows to access the figure and axis. @@ -41,7 +41,7 @@ ############################################################################## # We can also use the 'map' mode useful for high channel count -w_ts = sw.plot_timeseries(recording, mode='map', time_range=(5, 8), +w_ts = sw.plot_traces(recording, mode='map', time_range=(5, 8), show_channel_ids=True, order_channel_by_depth=True) ############################################################################## diff --git a/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py b/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py index 3c4e23f14a..2e364b13bc 100644 --- a/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py +++ b/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py @@ -22,7 +22,7 @@ class CompressedBinaryIblExtractorTest(RecordingCommonTestSuite, unittest.TestCa # ~ import matplotlib.pyplot as plt # ~ import spikeinterface.widgets as sw # ~ from probeinterface.plotting import plot_probe -# ~ sw.plot_timeseries(rec) +# ~ sw.plot_traces(rec) # ~ plot_probe(rec.get_probe()) # ~ plt.show() diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 5d6cc0eb16..95e5a097ff 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -105,10 +105,10 @@ def test_filter_opencl(): # rec2_cached0 = rec2.save(chunk_size=1000,verbose=False, progress_bar=True, n_jobs=4) # import matplotlib.pyplot as plt - # from spikeinterface.widgets import plot_timeseries - # plot_timeseries(rec, segment_index=0) - # plot_timeseries(rec_filtered, segment_index=0) - # plot_timeseries(rec2_cached0, segment_index=0) + # from spikeinterface.widgets import plot_traces + # plot_traces(rec, segment_index=0) + # plot_traces(rec_filtered, segment_index=0) + # plot_traces(rec2_cached0, segment_index=0) # plt.show() diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index 45db8440b9..b62a73a8cb 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -30,7 +30,7 @@ def test_normalize_by_quantile(): rec2.save(verbose=False) # import matplotlib.pyplot as plt - # from spikeinterface.widgets import plot_timeseries + # from spikeinterface.widgets import plot_traces # fig, ax = plt.subplots() # ax.plot(rec.get_traces(segment_index=0)[:, 0], color='g') # ax.plot(rec2.get_traces(segment_index=0)[:, 0], color='r') diff --git a/src/spikeinterface/preprocessing/tests/test_phase_shift.py b/src/spikeinterface/preprocessing/tests/test_phase_shift.py index 41293b6c25..b1ccc433b3 100644 --- a/src/spikeinterface/preprocessing/tests/test_phase_shift.py +++ b/src/spikeinterface/preprocessing/tests/test_phase_shift.py @@ -104,9 +104,9 @@ def test_phase_shift(): # ~ import matplotlib.pyplot as plt # ~ import spikeinterface.full as si - # ~ si.plot_timeseries(rec, segment_index=0, time_range=[0, 10]) - # ~ si.plot_timeseries(rec2, segment_index=0, time_range=[0, 10]) - # ~ si.plot_timeseries(rec3, segment_index=0, time_range=[0, 10]) + # ~ si.plot_traces(rec, segment_index=0, time_range=[0, 10]) + # ~ si.plot_traces(rec2, segment_index=0, time_range=[0, 10]) + # ~ si.plot_traces(rec3, segment_index=0, time_range=[0, 10]) # ~ plt.show() diff --git a/src/spikeinterface/preprocessing/tests/test_rectify.py b/src/spikeinterface/preprocessing/tests/test_rectify.py index d4f58d3cc3..cca41ebf7d 100644 --- a/src/spikeinterface/preprocessing/tests/test_rectify.py +++ b/src/spikeinterface/preprocessing/tests/test_rectify.py @@ -27,7 +27,7 @@ def test_rectify(): assert traces.shape[1] == 1 # import matplotlib.pyplot as plt - # from spikeinterface.widgets import plot_timeseries + # from spikeinterface.widgets import plot_traces # fig, ax = plt.subplots() # ax.plot(rec.get_traces(segment_index=0)[:, 0], color='g') # ax.plot(rec2.get_traces(segment_index=0)[:, 0], color='r') diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index b5ad24a5b3..e1a8ade22b 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -455,7 +455,7 @@ def plot_figure_1(benchmark, mode="average", cell_ind="auto"): ) print(benchmark.recording) - # si.plot_timeseries(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1]) + # si.plot_traces(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1]) # axs[0, 1].set_ylabel('Neurons') # si.plot_spikes_on_traces(benchmark.waveforms, unit_ids=[unit_id], time_range=(times[0]-0.01, times[0] + 0.1), unit_colors={unit_id : 'r'}, ax=axs[0, 1], diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 06f68a754e..81f2e4009b 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,5 +1,5 @@ # basics -# from .timeseries import plot_timeseries, TimeseriesWidget +# from .timeseries import plot_timeseries, TracesWidget from .rasters import plot_rasters, RasterWidget from .probemap import plot_probe_map, ProbeMapWidget diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py index 37bfab9d66..dd7c801e9c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py @@ -31,7 +31,7 @@ def plot(self): self._do_plot() -class AmplitudeTimeseriesWidget(AmplitudeBaseWidget): +class AmplitudeTracesWidget(AmplitudeBaseWidget): """ Plots waveform amplitudes distribution. @@ -130,12 +130,12 @@ def _do_plot(self): def plot_amplitudes_timeseries(*args, **kwargs): - W = AmplitudeTimeseriesWidget(*args, **kwargs) + W = AmplitudeTracesWidget(*args, **kwargs) W.plot() return W -plot_amplitudes_timeseries.__doc__ = AmplitudeTimeseriesWidget.__doc__ +plot_amplitudes_timeseries.__doc__ = AmplitudeTracesWidget.__doc__ def plot_amplitudes_distribution(*args, **kwargs): diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py index 5856549da3..ab6fa2ace5 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py @@ -6,7 +6,7 @@ import scipy.spatial -class TimeseriesWidget(BaseWidget): +class TracesWidget(BaseWidget): """ Plots recording timeseries. @@ -46,7 +46,7 @@ class TimeseriesWidget(BaseWidget): Returns ------- - W: TimeseriesWidget + W: TracesWidget The output widget """ @@ -225,9 +225,9 @@ def _initialize_stats(self): def plot_timeseries(*args, **kwargs): - W = TimeseriesWidget(*args, **kwargs) + W = TracesWidget(*args, **kwargs) W.plot() return W -plot_timeseries.__doc__ = TimeseriesWidget.__doc__ +plot_timeseries.__doc__ = TracesWidget.__doc__ diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index ab4e629a2e..e7bcff0832 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -2,7 +2,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from .timeseries import TimeseriesWidget +from .traces import TracesWidget from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.waveform_extractor import WaveformExtractor @@ -150,7 +150,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting = we.sorting # first plot time series - ts_widget = TimeseriesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) + ts_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) self.ax = ts_widget.ax self.axes = ts_widget.axes self.figure = ts_widget.figure @@ -257,7 +257,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] # plot timeseries - ts_widget = TimeseriesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) + ts_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) self.ax = ts_widget.ax self.axes = ts_widget.axes self.figure = ts_widget.figure diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 610da470e8..96c6ab80eb 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -86,16 +86,16 @@ def setUpClass(cls): cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) - def test_plot_timeseries(self): - possible_backends = list(sw.TimeseriesWidget.get_possible_backends()) + def test_plot_traces(self): + possible_backends = list(sw.TracesWidget.get_possible_backends()) for backend in possible_backends: if ON_GITHUB and backend == "sortingview": continue if backend not in self.skip_backends: - sw.plot_timeseries( + sw.plot_traces( self.recording, mode="map", show_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) - sw.plot_timeseries( + sw.plot_traces( self.recording, mode="map", show_channel_ids=True, @@ -105,8 +105,8 @@ def test_plot_timeseries(self): ) if backend != "sortingview": - sw.plot_timeseries(self.recording, mode="auto", backend=backend, **self.backend_kwargs[backend]) - sw.plot_timeseries( + sw.plot_traces(self.recording, mode="auto", backend=backend, **self.backend_kwargs[backend]) + sw.plot_traces( self.recording, mode="line", show_channel_ids=True, @@ -114,7 +114,7 @@ def test_plot_timeseries(self): **self.backend_kwargs[backend], ) # multi layer - sw.plot_timeseries( + sw.plot_traces( {"rec0": self.recording, "rec1": scale(self.recording, gain=0.8, offset=0)}, color="r", mode="line", @@ -337,7 +337,7 @@ def test_sorting_summary(self): # mytest.test_plot_unit_waveforms_density_map() # mytest.test_plot_unit_summary() # mytest.test_plot_all_amplitudes_distributions() - # mytest.test_plot_timeseries() + # mytest.test_plot_traces() # mytest.test_plot_unit_waveforms() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_templates() diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/traces.py similarity index 98% rename from src/spikeinterface/widgets/timeseries.py rename to src/spikeinterface/widgets/traces.py index 9439694639..53f1593260 100644 --- a/src/spikeinterface/widgets/timeseries.py +++ b/src/spikeinterface/widgets/traces.py @@ -7,7 +7,7 @@ from .utils import get_some_colors, array_to_image -class TimeseriesWidget(BaseWidget): +class TracesWidget(BaseWidget): """ Plots recording timeseries. @@ -54,7 +54,7 @@ class TimeseriesWidget(BaseWidget): Returns ------- - W: TimeseriesWidget + W: TracesWidget The output widget """ @@ -90,7 +90,7 @@ def __init__( recordings = {f"rec{i}": rec for i, rec in enumerate(recording)} rec0 = recordings[0] else: - raise ValueError("plot_timeseries recording must be recording or dict or list") + raise ValueError("plot_traces recording must be recording or dict or list") layer_keys = list(recordings.keys()) @@ -256,7 +256,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.legend(loc="upper right") elif dp.mode == "map": - assert len(dp.list_traces) == 1, 'plot_timeseries with mode="map" do not support multi recording' + assert len(dp.list_traces) == 1, 'plot_traces with mode="map" do not support multi recording' assert len(dp.clims) == 1 clim = list(dp.clims.values())[0] extent = (dp.time_range[0], dp.time_range[1], min_y, max_y) @@ -501,7 +501,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) - assert dp.mode == "map", 'sortingview plot_timeseries is only mode="map"' + assert dp.mode == "map", 'sortingview plot_traces is only mode="map"' if not dp.order_channel_by_depth: warnings.warn( diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index eab0345d53..f3c640ff16 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -1,3 +1,5 @@ +import warnings + from .base import backend_kwargs_desc from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget @@ -11,7 +13,7 @@ from .spikes_on_traces import SpikesOnTracesWidget from .template_metrics import TemplateMetricsWidget from .template_similarity import TemplateSimilarityWidget -from .timeseries import TimeseriesWidget +from .traces import TracesWidget from .unit_depths import UnitDepthsWidget from .unit_locations import UnitLocationsWidget from .unit_summary import UnitSummaryWidget @@ -32,7 +34,7 @@ SpikesOnTracesWidget, TemplateMetricsWidget, TemplateSimilarityWidget, - TimeseriesWidget, + TracesWidget, UnitDepthsWidget, UnitLocationsWidget, UnitSummaryWidget, @@ -79,10 +81,15 @@ plot_spikes_on_traces = SpikesOnTracesWidget plot_template_metrics = TemplateMetricsWidget plot_template_similarity = TemplateSimilarityWidget -plot_timeseries = TimeseriesWidget +plot_traces = TracesWidget plot_unit_depths = UnitDepthsWidget plot_unit_locations = UnitLocationsWidget plot_unit_summary = UnitSummaryWidget plot_unit_templates = UnitTemplatesWidget plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget plot_unit_waveforms = UnitWaveformsWidget + + +def plot_timeseries(*args, **kwargs): + warnings.warn("plot_timeseries() is now plot_traces()") + return plot_traces(*args, **kwargs) From 019a5c8d59ec8b696c3c8f737b2d38c0574b6bc9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Jul 2023 09:28:01 +0000 Subject: [PATCH 25/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/base.py | 7 +++---- src/spikeinterface/widgets/utils_sortingview.py | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index eaa151ccd9..dea46b8f51 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -77,10 +77,9 @@ def __init__( self.do_plot() # subclass must define one method per supported backend: - # def plot_matplotlib(self, data_plot, **backend_kwargs): - # def plot_ipywidgets(self, data_plot, **backend_kwargs): - # def plot_sortingview(self, data_plot, **backend_kwargs): - + # def plot_matplotlib(self, data_plot, **backend_kwargs): + # def plot_ipywidgets(self, data_plot, **backend_kwargs): + # def plot_sortingview(self, data_plot, **backend_kwargs): @classmethod def get_possible_backends(cls): diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 24ae481a6b..50bbab99df 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -3,7 +3,6 @@ from ..core.core_tools import check_json - def make_serializable(*args): dict_to_serialize = {int(i): a for i, a in enumerate(args)} serializable_dict = check_json(dict_to_serialize) From 9f6636b7320f01aaa9ccd81a54faabfe4f6365dd Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 11:32:03 +0200 Subject: [PATCH 26/31] Remove unecessary legacy widgets are are alreayd ported --- .../widgets/_legacy_mpl_widgets/__init__.py | 12 - .../widgets/_legacy_mpl_widgets/amplitudes.py | 147 ------------ .../_legacy_mpl_widgets/correlograms_.py | 107 --------- .../_legacy_mpl_widgets/depthamplitude.py | 58 ----- .../_legacy_mpl_widgets/unitlocalization_.py | 109 --------- .../_legacy_mpl_widgets/unitsummary.py | 104 --------- .../unitwaveformdensitymap_.py | 199 ---------------- .../_legacy_mpl_widgets/unitwaveforms_.py | 218 ------------------ 8 files changed, 954 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 81f2e4009b..c0dcd7ea6e 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -6,25 +6,15 @@ # isi/ccg/acg from .isidistribution import plot_isi_distribution, ISIDistributionWidget -# from .correlograms import (plot_crosscorrelograms, CrossCorrelogramsWidget, -# plot_autocorrelograms, AutoCorrelogramsWidget) - # peak activity from .activity import plot_peak_activity_map, PeakActivityMapWidget # waveform/PC related -# from .unitwaveforms import plot_unit_waveforms, plot_unit_templates -# from .unitwaveformdensitymap import plot_unit_waveform_density_map, UnitWaveformDensityMapWidget -# from .amplitudes import plot_amplitudes_distribution from .principalcomponent import plot_principal_component -# from .unitlocalization import plot_unit_localization, UnitLocalizationWidget - # units on probe from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget -# from .depthamplitude import plot_units_depth_vs_amplitude - # comparison related from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget from .agreementmatrix import plot_agreement_matrix, AgreementMatrixWidget @@ -77,8 +67,6 @@ ComparisonPerformancesByTemplateSimilarity, ) -# unit summary -# from .unitsummary import plot_unit_summary, UnitSummaryWidget # unit presence from .presence import plot_presence, PresenceWidget diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py deleted file mode 100644 index dd7c801e9c..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py +++ /dev/null @@ -1,147 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget - -from ...postprocessing import compute_spike_amplitudes -from .utils import get_unit_colors - - -class AmplitudeBaseWidget(BaseWidget): - def __init__(self, waveform_extractor, unit_ids=None, compute_kwargs={}, unit_colors=None, figure=None, ax=None): - BaseWidget.__init__(self, figure, ax) - - self.we = waveform_extractor - - if self.we.is_extension("spike_amplitudes"): - sac = self.we.load_extension("spike_amplitudes") - self.amplitudes = sac.get_data(outputs="by_unit") - else: - self.amplitudes = compute_spike_amplitudes(self.we, outputs="by_unit", **compute_kwargs) - - if unit_ids is None: - unit_ids = waveform_extractor.sorting.unit_ids - self.unit_ids = unit_ids - - if unit_colors is None: - unit_colors = get_unit_colors(self.we.sorting) - self.unit_colors = unit_colors - - def plot(self): - self._do_plot() - - -class AmplitudeTracesWidget(AmplitudeBaseWidget): - """ - Plots waveform amplitudes distribution. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - - amplitudes: None or pre computed amplitudes - If None then amplitudes are recomputed - peak_sign: 'neg', 'pos', 'both' - In case of recomputing amplitudes. - - Returns - ------- - W: AmplitudeDistributionWidget - The output widget - """ - - def _do_plot(self): - sorting = self.we.sorting - # ~ unit_ids = sorting.unit_ids - num_seg = sorting.get_num_segments() - fs = sorting.get_sampling_frequency() - - # TODO handle segment - ax = self.ax - for i, unit_id in enumerate(self.unit_ids): - for segment_index in range(num_seg): - times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) - times = times / fs - amps = self.amplitudes[segment_index][unit_id] - ax.scatter(times, amps, color=self.unit_colors[unit_id], s=3, alpha=1) - - if i == 0: - ax.set_title(f"segment {segment_index}") - if i == len(self.unit_ids) - 1: - ax.set_xlabel("Times [s]") - if segment_index == 0: - ax.set_ylabel(f"Amplitude") - - ylims = ax.get_ylim() - if np.max(ylims) < 0: - ax.set_ylim(min(ylims), 0) - if np.min(ylims) > 0: - ax.set_ylim(0, max(ylims)) - - -class AmplitudeDistributionWidget(AmplitudeBaseWidget): - """ - Plots waveform amplitudes distribution. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - - amplitudes: None or pre computed amplitudes - If None then amplitudes are recomputed - peak_sign: 'neg', 'pos', 'both' - In case of recomputing amplitudes. - - Returns - ------- - W: AmplitudeDistributionWidget - The output widget - """ - - def _do_plot(self): - sorting = self.we.sorting - unit_ids = sorting.unit_ids - num_seg = sorting.get_num_segments() - - ax = self.ax - unit_amps = [] - for i, unit_id in enumerate(unit_ids): - amps = [] - for segment_index in range(num_seg): - amps.append(self.amplitudes[segment_index][unit_id]) - amps = np.concatenate(amps) - unit_amps.append(amps) - parts = ax.violinplot(unit_amps, showmeans=False, showmedians=False, showextrema=False) - - for i, pc in enumerate(parts["bodies"]): - color = self.unit_colors[unit_ids[i]] - pc.set_facecolor(color) - pc.set_edgecolor("black") - pc.set_alpha(1) - - ax.set_xticks(np.arange(len(unit_ids)) + 1) - ax.set_xticklabels([str(unit_id) for unit_id in unit_ids]) - - ylims = ax.get_ylim() - if np.max(ylims) < 0: - ax.set_ylim(min(ylims), 0) - if np.min(ylims) > 0: - ax.set_ylim(0, max(ylims)) - - -def plot_amplitudes_timeseries(*args, **kwargs): - W = AmplitudeTracesWidget(*args, **kwargs) - W.plot() - return W - - -plot_amplitudes_timeseries.__doc__ = AmplitudeTracesWidget.__doc__ - - -def plot_amplitudes_distribution(*args, **kwargs): - W = AmplitudeDistributionWidget(*args, **kwargs) - W.plot() - return W - - -plot_amplitudes_distribution.__doc__ = AmplitudeDistributionWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py deleted file mode 100644 index 8e12559066..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py +++ /dev/null @@ -1,107 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from .basewidget import BaseWidget - -from spikeinterface.postprocessing import compute_correlograms - - -class CrossCorrelogramsWidget(BaseWidget): - """ - Plots spike train cross-correlograms. - The diagonal is auto-correlogram. - - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - unit_ids: list - List of unit ids - bin_ms: float - bins duration in ms - window_ms: float - Window duration in ms - """ - - def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, axes=None): - if unit_ids is not None: - sorting = sorting.select_units(unit_ids) - self.sorting = sorting - self.compute_kwargs = dict(window_ms=window_ms, bin_ms=bin_ms) - - if axes is None: - n = len(sorting.unit_ids) - fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, sharey=True) - BaseWidget.__init__(self, None, None, axes) - - def plot(self): - correlograms, bins = compute_correlograms(self.sorting, **self.compute_kwargs) - bin_width = bins[1] - bins[0] - unit_ids = self.sorting.unit_ids - for i, unit_id1 in enumerate(unit_ids): - for j, unit_id2 in enumerate(unit_ids): - ccg = correlograms[i, j] - ax = self.axes[i, j] - if i == j: - color = "g" - else: - color = "k" - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - - for i, unit_id in enumerate(unit_ids): - self.axes[0, i].set_title(str(unit_id)) - self.axes[-1, i].set_xlabel("CCG (ms)") - - -def plot_crosscorrelograms(*args, **kwargs): - W = CrossCorrelogramsWidget(*args, **kwargs) - W.plot() - return W - - -plot_crosscorrelograms.__doc__ = CrossCorrelogramsWidget.__doc__ - - -class AutoCorrelogramsWidget(BaseWidget): - """ - Plots spike train auto-correlograms. - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - unit_ids: list - List of unit ids - bin_ms: float - bins duration in ms - window_ms: float - Window duration in ms - """ - - def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, ncols=5, axes=None): - if unit_ids is not None: - sorting = sorting.select_units(unit_ids) - self.sorting = sorting - self.compute_kwargs = dict(window_ms=window_ms, bin_ms=bin_ms) - - if axes is None: - num_axes = len(sorting.unit_ids) - BaseWidget.__init__(self, None, None, axes, ncols=ncols, num_axes=num_axes) - - def plot(self): - correlograms, bins = compute_correlograms(self.sorting, **self.compute_kwargs) - bin_width = bins[1] - bins[0] - unit_ids = self.sorting.unit_ids - for i, unit_id in enumerate(unit_ids): - ccg = correlograms[i, i] - ax = self.axes.flatten()[i] - color = "g" - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - ax.set_title(str(unit_id)) - - -def plot_autocorrelograms(*args, **kwargs): - W = AutoCorrelogramsWidget(*args, **kwargs) - W.plot() - return W - - -plot_autocorrelograms.__doc__ = AutoCorrelogramsWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py deleted file mode 100644 index a382fee9bc..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py +++ /dev/null @@ -1,58 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget - -from ...postprocessing import get_template_extremum_channel, get_template_extremum_amplitude -from .utils import get_unit_colors - - -class UnitsDepthAmplitudeWidget(BaseWidget): - def __init__(self, waveform_extractor, peak_sign="neg", depth_axis=1, unit_colors=None, figure=None, ax=None): - BaseWidget.__init__(self, figure, ax) - - self.we = waveform_extractor - self.peak_sign = peak_sign - self.depth_axis = depth_axis - if unit_colors is None: - unit_colors = get_unit_colors(self.we.sorting) - self.unit_colors = unit_colors - - def plot(self): - ax = self.ax - we = self.we - unit_ids = we.unit_ids - - channels_index = get_template_extremum_channel(we, peak_sign=self.peak_sign, outputs="index") - contact_positions = we.get_channel_locations() - - channel_depth = contact_positions[:, self.depth_axis] - unit_depth = [channel_depth[channels_index[unit_id]] for unit_id in unit_ids] - - unit_amplitude = get_template_extremum_amplitude(we, peak_sign=self.peak_sign) - unit_amplitude = np.abs([unit_amplitude[unit_id] for unit_id in unit_ids]) - - colors = [self.unit_colors[unit_id] for unit_id in unit_ids] - - num_spikes = np.zeros(len(unit_ids)) - for i, unit_id in enumerate(unit_ids): - for segment_index in range(we.get_num_segments()): - st = we.sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - num_spikes[i] += st.size - - size = num_spikes / max(num_spikes) * 120 - ax.scatter(unit_amplitude, unit_depth, color=colors, s=size) - - ax.set_aspect(3) - ax.set_xlabel("amplitude") - ax.set_ylabel("depth [um]") - ax.set_xlim(0, max(unit_amplitude) * 1.2) - - -def plot_units_depth_vs_amplitude(*args, **kwargs): - W = UnitsDepthAmplitudeWidget(*args, **kwargs) - W.plot() - return W - - -plot_units_depth_vs_amplitude.__doc__ = UnitsDepthAmplitudeWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py deleted file mode 100644 index a2b8beea3f..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py +++ /dev/null @@ -1,109 +0,0 @@ -import numpy as np -import matplotlib.pylab as plt -from .basewidget import BaseWidget - -from probeinterface.plotting import plot_probe - -from spikeinterface.postprocessing import compute_unit_locations - -from .utils import get_unit_colors - - -class UnitLocalizationWidget(BaseWidget): - """ - Plot unit localization on probe. - - Parameters - ---------- - waveform_extractor: WaveformaExtractor - WaveformaExtractorr object - peaks: None or numpy array - Optionally can give already detected peaks - to avoid multiple computation. - method: str default 'center_of_mass' - Method used to estimate unit localization if 'unit_location' is None - method_kwargs: dict - Option for the method - unit_colors: None or dict - A dict key is unit_id and value is any color format handled by matplotlib. - If None, then the get_unit_colors() is internally used. - with_channel_ids: bool False default - add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: ProbeMapWidget - The output widget - """ - - def __init__( - self, - waveform_extractor, - method="center_of_mass", - method_kwargs={}, - unit_colors=None, - with_channel_ids=False, - figure=None, - ax=None, - ): - BaseWidget.__init__(self, figure, ax) - - self.waveform_extractor = waveform_extractor - self.method = method - self.method_kwargs = method_kwargs - - if unit_colors is None: - unit_colors = get_unit_colors(waveform_extractor.sorting) - self.unit_colors = unit_colors - - self.with_channel_ids = with_channel_ids - - def plot(self): - we = self.waveform_extractor - unit_ids = we.unit_ids - - if we.is_extension("unit_locations"): - unit_locations = we.load_extension("unit_locations").get_data() - else: - unit_locations = compute_unit_locations(we, method=self.method, **self.method_kwargs) - - ax = self.ax - probegroup = we.get_probegroup() - probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) - contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) - - for probe in probegroup.probes: - text_on_contact = None - if self.with_channel_ids: - text_on_contact = self.waveform_extractor.recording.channel_ids - - poly_contact, poly_contour = plot_probe( - probe, - ax=ax, - contacts_colors="w", - contacts_kargs=contacts_kargs, - probe_shape_kwargs=probe_shape_kwargs, - text_on_contact=text_on_contact, - ) - poly_contact.set_zorder(2) - if poly_contour is not None: - poly_contour.set_zorder(1) - - ax.set_title("") - - color = np.array([self.unit_colors[unit_id] for unit_id in unit_ids]) - loc = ax.scatter(unit_locations[:, 0], unit_locations[:, 1], marker="1", color=color, s=80, lw=3) - loc.set_zorder(3) - - -def plot_unit_localization(*args, **kwargs): - W = UnitLocalizationWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_localization.__doc__ = UnitLocalizationWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py deleted file mode 100644 index a1d0589abc..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from .basewidget import BaseWidget - -from .utils import get_unit_colors - -from .unitprobemap import plot_unit_probe_map -from .unitwaveformdensitymap_ import plot_unit_waveform_density_map -from .amplitudes import plot_amplitudes_timeseries -from .unitwaveforms_ import plot_unit_waveforms -from .isidistribution import plot_isi_distribution - - -class UnitSummaryWidget(BaseWidget): - """ - Plot a unit summary. - - If amplitudes are alreday computed they are displayed. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - The waveform extractor object - unit_id: into or str - The unit id to plot the summary of - unit_colors: list or None - Optional matplotlib color for the unit - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: UnitSummaryWidget - The output widget - """ - - def __init__(self, waveform_extractor, unit_id, unit_colors=None, figure=None, ax=None): - assert ax is None - # ~ assert axes is None - - if figure is None: - figure = plt.figure( - constrained_layout=False, - figsize=(15, 7), - ) - - BaseWidget.__init__(self, figure, None) - - self.waveform_extractor = waveform_extractor - self.recording = waveform_extractor.recording - self.sorting = waveform_extractor.sorting - self.unit_id = unit_id - - if unit_colors is None: - unit_colors = get_unit_colors(self.sorting) - self.unit_colors = unit_colors - - def plot(self): - we = self.waveform_extractor - - fig = self.figure - self.ax.remove() - - if we.is_extension("spike_amplitudes"): - nrows = 3 - else: - nrows = 2 - - gs = fig.add_gridspec(nrows, 6) - - ax = fig.add_subplot(gs[:, 0]) - plot_unit_probe_map(we, unit_ids=[self.unit_id], axes=[ax], colorbar=False) - ax.set_title("") - - ax = fig.add_subplot(gs[0:2, 1:3]) - plot_unit_waveforms(we, unit_ids=[self.unit_id], radius_um=60, axes=[ax], unit_colors=self.unit_colors) - ax.set_title(None) - - ax = fig.add_subplot(gs[0:2, 3:5]) - plot_unit_waveform_density_map(we, unit_ids=[self.unit_id], max_channels=1, ax=ax, same_axis=True) - ax.set_ylabel(None) - - ax = fig.add_subplot(gs[0:2, 5]) - plot_isi_distribution(we.sorting, unit_ids=[self.unit_id], axes=[ax]) - ax.set_title("") - - if we.is_extension("spike_amplitudes"): - ax = fig.add_subplot(gs[-1, 1:]) - plot_amplitudes_timeseries(we, unit_ids=[self.unit_id], ax=ax, unit_colors=self.unit_colors) - ax.set_ylabel(None) - ax.set_title(None) - - fig.suptitle(f"Unit ID: {self.unit_id}") - - -def plot_unit_summary(*args, **kwargs): - W = UnitSummaryWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_summary.__doc__ = UnitSummaryWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py deleted file mode 100644 index c5cbe07a7b..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py +++ /dev/null @@ -1,199 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget -from .utils import get_unit_colors -from ...postprocessing import get_template_channel_sparsity - - -class UnitWaveformDensityMapWidget(BaseWidget): - """ - Plots unit waveforms using heat map density. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - channel_ids: list - The channel ids to display - unit_ids: list - List of unit ids. - plot_templates: bool - If True, templates are plotted over the waveforms - max_channels : None or int - If not None only max_channels are displayed per units. - Incompatible with with `radius_um` - radius_um: None or float - If not None, all channels within a circle around the peak waveform will be displayed - Incompatible with with `max_channels` - unit_colors: None or dict - A dict key is unit_id and value is any color format handled by matplotlib. - If None, then the get_unit_colors() is internally used. - same_axis: bool - If True then all density are plot on the same axis and then channels is the union - all channel per units. - set_title: bool - Create a plot title with the unit number if True. - plot_channels: bool - Plot channel locations below traces, only used if channel_locs is True - """ - - def __init__( - self, - waveform_extractor, - channel_ids=None, - unit_ids=None, - max_channels=None, - radius_um=None, - same_axis=False, - unit_colors=None, - ax=None, - axes=None, - ): - self.waveform_extractor = waveform_extractor - self.recording = waveform_extractor.recording - self.sorting = waveform_extractor.sorting - - if unit_ids is None: - unit_ids = self.sorting.get_unit_ids() - self.unit_ids = unit_ids - - if channel_ids is None: - channel_ids = self.recording.get_channel_ids() - self.channel_ids = channel_ids - - if unit_colors is None: - unit_colors = get_unit_colors(self.sorting) - self.unit_colors = unit_colors - - if radius_um is not None: - assert max_channels is None, "radius_um and max_channels are mutually exclusive" - if max_channels is not None: - assert radius_um is None, "radius_um and max_channels are mutually exclusive" - - self.radius_um = radius_um - self.max_channels = max_channels - self.same_axis = same_axis - - if axes is None and ax is None: - if same_axis: - fig, ax = plt.subplots() - axes = None - else: - nrows = len(unit_ids) - fig, axes = plt.subplots(nrows=nrows, squeeze=False) - axes = axes[:, 0] - ax = None - BaseWidget.__init__(self, figure=None, ax=ax, axes=axes) - - def plot(self): - we = self.waveform_extractor - - # channel sparsity - if self.radius_um is not None: - channel_inds = get_template_channel_sparsity(we, method="radius", outputs="index", radius_um=self.radius_um) - elif self.max_channels is not None: - channel_inds = get_template_channel_sparsity( - we, method="best_channels", outputs="index", num_channels=self.max_channels - ) - else: - # all channels - channel_inds = {unit_id: np.arange(len(self.channel_ids)) for unit_id in self.unit_ids} - channel_inds = {unit_id: inds for unit_id, inds in channel_inds.items() if unit_id in self.unit_ids} - - if self.same_axis: - # channel union - inds = np.unique(np.concatenate([inds.tolist() for inds in channel_inds.values()])) - channel_inds = {unit_id: inds for unit_id in self.unit_ids} - - # bins - templates = we.get_all_templates(unit_ids=self.unit_ids, mode="median") - bin_min = np.min(templates) * 1.3 - bin_max = np.max(templates) * 1.3 - bin_size = (bin_max - bin_min) / 100 - bins = np.arange(bin_min, bin_max, bin_size) - - # 2d histograms - all_hist2d = None - for unit_index, unit_id in enumerate(self.unit_ids): - chan_inds = channel_inds[unit_id] - - wfs = we.get_waveforms(unit_id) - wfs = wfs[:, :, chan_inds] - - # make histogram density - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) - hist2d = np.zeros((wfs_flat.shape[1], bins.size)) - indexes0 = np.arange(wfs_flat.shape[1]) - - wf_bined = np.floor((wfs_flat - bin_min) / bin_size).astype("int32") - wf_bined = wf_bined.clip(0, bins.size - 1) - for d in wf_bined: - hist2d[indexes0, d] += 1 - - if self.same_axis: - if all_hist2d is None: - all_hist2d = hist2d - else: - all_hist2d += hist2d - else: - ax = self.axes[unit_index] - im = ax.imshow( - hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], bin_min, bin_max), - cmap="hot", - ) - - if self.same_axis: - ax = self.ax - im = ax.imshow( - all_hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], bin_min, bin_max), - cmap="hot", - ) - - # plot median - for unit_index, unit_id in enumerate(self.unit_ids): - if self.same_axis: - ax = self.ax - else: - ax = self.axes[unit_index] - chan_inds = channel_inds[unit_id] - template = templates[unit_index, :, chan_inds] - template_flat = template.flatten() - color = self.unit_colors[unit_id] - ax.plot(template_flat, color=color, lw=1) - - # final cosmetics - for unit_index, unit_id in enumerate(self.unit_ids): - if self.same_axis: - ax = self.ax - if unit_index != 0: - continue - else: - ax = self.axes[unit_index] - chan_inds = channel_inds[unit_id] - for i, chan_ind in enumerate(chan_inds): - if i != 0: - ax.axvline(i * wfs.shape[1], color="w", lw=3) - channel_id = self.recording.channel_ids[chan_ind] - x = i * wfs.shape[1] + wfs.shape[1] // 2 - y = (bin_max + bin_min) / 2.0 - ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") - - ax.set_xticks([]) - ax.set_ylabel(f"unit_id {unit_id}") - - -def plot_unit_waveform_density_map(*args, **kwargs): - W = UnitWaveformDensityMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_waveform_density_map.__doc__ = UnitWaveformDensityMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py deleted file mode 100644 index a1e28bbb82..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py +++ /dev/null @@ -1,218 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget -from .utils import get_unit_colors -from ...postprocessing import get_template_channel_sparsity - - -class UnitWaveformsWidget(BaseWidget): - """ - Plots unit waveforms. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - channel_ids: list - The channel ids to display - unit_ids: list - List of unit ids. - plot_templates: bool - If True, templates are plotted over the waveforms - radius_um: None or float - If not None, all channels within a circle around the peak waveform will be displayed - Incompatible with with `max_channels` - max_channels : None or int - If not None only max_channels are displayed per units. - Incompatible with with `radius_um` - set_title: bool - Create a plot title with the unit number if True. - plot_channels: bool - Plot channel locations below traces. - axis_equal: bool - Equal aspect ratio for x and y axis, to visualize the array geometry to scale. - lw: float - Line width for the traces. - unit_colors: None or dict - A dict key is unit_id and value is any color format handled by matplotlib. - If None, then the get_unit_colors() is internally used. - unit_selected_waveforms: None or dict - A dict key is unit_id and value is the subset of waveforms indices that should be - be displayed - show_all_channels: bool - Show the whole probe if True, or only selected channels if False - The axis to be used. If not given an axis is created - axes: list of matplotlib axes - The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax - and figure parameters are ignored - """ - - def __init__( - self, - waveform_extractor, - channel_ids=None, - unit_ids=None, - plot_waveforms=True, - plot_templates=True, - plot_channels=False, - unit_colors=None, - max_channels=None, - radius_um=None, - ncols=5, - axes=None, - lw=2, - axis_equal=False, - unit_selected_waveforms=None, - set_title=True, - ): - self.waveform_extractor = waveform_extractor - self._recording = waveform_extractor.recording - self._sorting = waveform_extractor.sorting - sorting = waveform_extractor.sorting - - if unit_ids is None: - unit_ids = self._sorting.get_unit_ids() - self._unit_ids = unit_ids - if channel_ids is None: - channel_ids = self._recording.get_channel_ids() - self._channel_ids = channel_ids - - if unit_colors is None: - unit_colors = get_unit_colors(self._sorting) - self.unit_colors = unit_colors - - self.ncols = ncols - self._plot_waveforms = plot_waveforms - self._plot_templates = plot_templates - self._plot_channels = plot_channels - - if radius_um is not None: - assert max_channels is None, "radius_um and max_channels are mutually exclusive" - if max_channels is not None: - assert radius_um is None, "radius_um and max_channels are mutually exclusive" - - self.radius_um = radius_um - self.max_channels = max_channels - self.unit_selected_waveforms = unit_selected_waveforms - - # TODO - self._lw = lw - self._axis_equal = axis_equal - - self._set_title = set_title - - if axes is None: - num_axes = len(unit_ids) - else: - num_axes = None - BaseWidget.__init__(self, None, None, axes, ncols=ncols, num_axes=num_axes) - - def plot(self): - self._do_plot() - - def _do_plot(self): - we = self.waveform_extractor - unit_ids = self._unit_ids - channel_ids = self._channel_ids - - channel_locations = self._recording.get_channel_locations(channel_ids=channel_ids) - templates = we.get_all_templates(unit_ids=unit_ids) - - xvectors, y_scale, y_offset = get_waveforms_scales(we, templates, channel_locations) - - ncols = min(self.ncols, len(unit_ids)) - nrows = int(np.ceil(len(unit_ids) / ncols)) - - if self.radius_um is not None: - channel_inds = get_template_channel_sparsity(we, method="radius", outputs="index", radius_um=self.radius_um) - elif self.max_channels is not None: - channel_inds = get_template_channel_sparsity( - we, method="best_channels", outputs="index", num_channels=self.max_channels - ) - else: - # all channels - channel_inds = {unit_id: slice(None) for unit_id in unit_ids} - - for i, unit_id in enumerate(unit_ids): - ax = self.axes.flatten()[i] - color = self.unit_colors[unit_id] - - chan_inds = channel_inds[unit_id] - xvectors_flat = xvectors[:, chan_inds].T.flatten() - - # plot waveforms - if self._plot_waveforms: - wfs = we.get_waveforms(unit_id) - if self.unit_selected_waveforms is not None: - wfs = wfs[self.unit_selected_waveforms[unit_id]][:, :, chan_inds] - else: - wfs = wfs[:, :, chan_inds] - wfs = wfs * y_scale + y_offset[None, :, chan_inds] - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T - ax.plot(xvectors_flat, wfs_flat, lw=1, alpha=0.3, color=color) - - # plot template - if self._plot_templates: - template = templates[i, :, :][:, chan_inds] * y_scale + y_offset[:, chan_inds] - if self._plot_waveforms and self._plot_templates: - color = "k" - ax.plot(xvectors_flat, template.T.flatten(), lw=1, color=color) - template_label = unit_ids[i] - ax.set_title(f"template {template_label}") - - # plot channels - if self._plot_channels: - # TODO enhance this - ax.scatter(channel_locations[:, 0], channel_locations[:, 1], color="k") - - -def get_waveforms_scales(we, templates, channel_locations): - """ - Return scales and x_vector for templates plotting - """ - wf_max = np.max(templates) - wf_min = np.max(templates) - - x_chans = np.unique(channel_locations[:, 0]) - if x_chans.size > 1: - delta_x = np.min(np.diff(x_chans)) - else: - delta_x = 40.0 - - y_chans = np.unique(channel_locations[:, 1]) - if y_chans.size > 1: - delta_y = np.min(np.diff(y_chans)) - else: - delta_y = 40.0 - - m = max(np.abs(wf_max), np.abs(wf_min)) - y_scale = delta_y / m * 0.7 - - y_offset = channel_locations[:, 1][None, :] - - xvect = delta_x * (np.arange(we.nsamples) - we.nbefore) / we.nsamples * 0.7 - - xvectors = channel_locations[:, 0][None, :] + xvect[:, None] - # put nan for discontinuity - xvectors[-1, :] = np.nan - - return xvectors, y_scale, y_offset - - -def plot_unit_waveforms(*args, **kwargs): - W = UnitWaveformsWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_waveforms.__doc__ = UnitWaveformsWidget.__doc__ - - -def plot_unit_templates(*args, **kwargs): - kwargs["plot_waveforms"] = False - W = UnitWaveformsWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_templates.__doc__ = UnitWaveformsWidget.__doc__ From 085a99f6045dbe896a2d560c027d4327dd2a19cd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 15:01:48 +0200 Subject: [PATCH 27/31] Fix plot_traces SV and add plot_motion to API --- doc/api.rst | 1 + src/spikeinterface/widgets/traces.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index 932c989c19..2e9fc1567a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -269,6 +269,7 @@ spikeinterface.widgets .. autofunction:: plot_amplitudes .. autofunction:: plot_autocorrelograms .. autofunction:: plot_crosscorrelograms + .. autofunction:: plot_motion .. autofunction:: plot_quality_metrics .. autofunction:: plot_sorting_summary .. autofunction:: plot_spike_locations diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 53f1593260..c9dc04811a 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -498,7 +498,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): except ImportError: raise ImportError("To use the timeseries in sorting view you need the pyvips package.") - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) assert dp.mode == "map", 'sortingview plot_traces is only mode="map"' From aa3b7c47318552060aae9b13ae9c1e5b3db0a080 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 21 Jul 2023 09:47:28 +0200 Subject: [PATCH 28/31] fix plot_trace legend --- src/spikeinterface/widgets/traces.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index c9dc04811a..405c4b6b79 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -290,7 +290,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): check_ipywidget_backend() self.next_data_plot = data_plot.copy() - + self.next_data_plot["add_legend"] = False + recordings = data_plot["recordings"] # first layer From 1855b8dca1b929428be0560875d21af30f6fcf48 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jul 2023 07:48:13 +0000 Subject: [PATCH 29/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/traces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 405c4b6b79..9a2ec4a215 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -291,7 +291,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.next_data_plot = data_plot.copy() self.next_data_plot["add_legend"] = False - + recordings = data_plot["recordings"] # first layer From 734670510ef3d58f2943ec1e36b79b1e4f6b2a98 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 21 Jul 2023 14:13:04 +0200 Subject: [PATCH 30/31] feedback from Alessio --- src/spikeinterface/widgets/all_amplitudes_distributions.py | 3 --- src/spikeinterface/widgets/quality_metrics.py | 2 -- src/spikeinterface/widgets/sorting_summary.py | 2 -- src/spikeinterface/widgets/template_similarity.py | 2 -- src/spikeinterface/widgets/tests/test_widgets.py | 2 -- src/spikeinterface/widgets/unit_depths.py | 2 -- src/spikeinterface/widgets/unit_locations.py | 2 -- 7 files changed, 15 deletions(-) diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index 280662fd7a..e8b25f6823 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -50,9 +50,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure - from matplotlib.patches import Ellipse - from matplotlib.lines import Line2D - dp = to_attr(data_plot) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index 459a32e6f2..4a6b46b72d 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -22,8 +22,6 @@ class QualityMetricsWidget(MetricsBaseWidget): For sortingview backend, if True the unit selector is not displayed, default False """ - # possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 9291de2956..b9760205f9 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -34,8 +34,6 @@ class SortingSummaryWidget(BaseWidget): (sortingview backend) """ - # possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index 69aad70b1f..63ac177835 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -25,8 +25,6 @@ class TemplateSimilarityWidget(BaseWidget): If True, color bar is displayed, default True. """ - # possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 96c6ab80eb..7bf508fe71 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -13,7 +13,6 @@ from spikeinterface import extract_waveforms, load_waveforms, download_dataset, compute_sparsity -# from spikeinterface.widgets import HAVE_MPL, HAVE_SV import spikeinterface.extractors as se @@ -36,7 +35,6 @@ else: cache_folder = Path("cache_folder") / "widgets" -print(cache_folder) ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index e48f274962..1aeae254c8 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -24,8 +24,6 @@ class UnitDepthsWidget(BaseWidget): Sign of peak for amplitudes, default 'neg' """ - # possible_backends = {} - def __init__( self, waveform_extractor, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs ): diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index f8ea042f84..42267e711f 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -33,8 +33,6 @@ class UnitLocationsWidget(BaseWidget): If True, the axis is set to off, default False (matplotlib backend) """ - # possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, From 370dc66ab2f89b51d411b366e9d650443852043d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jul 2023 12:16:34 +0000 Subject: [PATCH 31/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/tests/test_widgets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 7bf508fe71..a5f75ebf50 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -14,7 +14,6 @@ from spikeinterface import extract_waveforms, load_waveforms, download_dataset, compute_sparsity - import spikeinterface.extractors as se import spikeinterface.widgets as sw import spikeinterface.comparison as sc