Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ipywidgets backend #2035

Merged
78 changes: 43 additions & 35 deletions src/spikeinterface/widgets/amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,16 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
else:
bins = dp.bins
ax_hist = self.axes.flatten()[1]
ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8)
# this is super slow, using plot and np.histogram is really much faster (and nicer!)
# ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8)
count, bins = np.histogram(amps, bins=bins)
ax_hist.plot(count, bins[:-1], color=dp.unit_colors[unit_id], alpha=0.8)

if dp.plot_histograms:
ax_hist = self.axes.flatten()[1]
ax_hist.set_ylim(scatter_ax.get_ylim())
ax_hist.axis("off")
self.figure.tight_layout()
# self.figure.tight_layout()

if dp.plot_legend:
if hasattr(self, "legend") and self.legend is not None:
Expand All @@ -171,9 +174,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
import ipywidgets.widgets as widgets

# import ipywidgets.widgets as widgets
import ipywidgets.widgets as W
from IPython.display import display
from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller
from .utils_ipywidgets import check_ipywidget_backend, UnitSelector

check_ipywidget_backend()

Expand All @@ -188,60 +193,63 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
ratios = [0.15, 0.85]

with plt.ioff():
output = widgets.Output()
output = W.Output()
with output:
self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
plt.show()

data_plot["unit_ids"] = data_plot["unit_ids"][:1]
unit_widget, unit_controller = make_unit_controller(
data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm
)
self.unit_selector = UnitSelector(we.unit_ids)
self.unit_selector.value = list(we.unit_ids)[:1]

plot_histograms = widgets.Checkbox(
self.checkbox_histograms = W.Checkbox(
value=data_plot["plot_histograms"],
description="plot histograms",
disabled=False,
description="hist",
)

footer = plot_histograms

self.controller = {"plot_histograms": plot_histograms}
self.controller.update(unit_controller)

for w in self.controller.values():
w.observe(self._update_ipywidget)
left_sidebar = W.VBox(
children=[
self.unit_selector,
self.checkbox_histograms,
],
layout=W.Layout(align_items="center", width="4cm", height="100%"),
)

self.widget = widgets.AppLayout(
self.widget = W.AppLayout(
center=self.figure.canvas,
left_sidebar=unit_widget,
left_sidebar=left_sidebar,
pane_widths=ratios + [0],
footer=footer,
)

# a first update
self._update_ipywidget(None)
self._full_update_plot()

self.unit_selector.observe(self._update_plot, names="value", type="change")
self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)

def _update_ipywidget(self, change):
def _full_update_plot(self, change=None):
self.figure.clear()
data_plot = self.next_data_plot
data_plot["unit_ids"] = self.unit_selector.value
data_plot["plot_histograms"] = self.checkbox_histograms.value
data_plot["plot_legend"] = False

unit_ids = self.controller["unit_ids"].value
plot_histograms = self.controller["plot_histograms"].value
backend_kwargs = dict(figure=self.figure, axes=None, ax=None)
self.plot_matplotlib(data_plot, **backend_kwargs)
self._update_plot()

# matplotlib next_data_plot dict update at each call
data_plot = self.next_data_plot
data_plot["unit_ids"] = unit_ids
data_plot["plot_histograms"] = plot_histograms
def _update_plot(self, change=None):
for ax in self.axes.flatten():
ax.clear()

backend_kwargs = {}
# backend_kwargs["figure"] = self.fig
backend_kwargs["figure"] = self.figure
backend_kwargs["axes"] = None
backend_kwargs["ax"] = None
data_plot = self.next_data_plot
data_plot["unit_ids"] = self.unit_selector.value
data_plot["plot_histograms"] = self.checkbox_histograms.value
data_plot["plot_legend"] = False

backend_kwargs = dict(figure=None, axes=self.axes, ax=None)
self.plot_matplotlib(data_plot, **backend_kwargs)

self.figure.canvas.draw()
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/widgets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ def set_default_plotter_backend(backend):
"width_cm": "Width of the figure in cm (default 10)",
"height_cm": "Height of the figure in cm (default 6)",
"display": "If True, widgets are immediately displayed",
# "controllers": ""
},
"ephyviewer": {},
}

default_backend_kwargs = {
"matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None},
"sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None},
"ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True},
"ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True, "controllers": None},
"ephyviewer": {},
}

Expand Down
21 changes: 8 additions & 13 deletions src/spikeinterface/widgets/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
import ipywidgets.widgets as widgets
from IPython.display import display
from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller
from .utils_ipywidgets import check_ipywidget_backend, UnitSelector

check_ipywidget_backend()

Expand All @@ -147,34 +147,28 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
with output:
self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
plt.show()
if data_plot["unit_ids"] is None:
data_plot["unit_ids"] = []

unit_widget, unit_controller = make_unit_controller(
data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm
)

self.controller = unit_controller

for w in self.controller.values():
w.observe(self._update_ipywidget)
self.unit_selector = UnitSelector(data_plot["sorting"].unit_ids)
self.unit_selector.value = []

self.widget = widgets.AppLayout(
center=self.figure.canvas,
left_sidebar=unit_widget,
left_sidebar=self.unit_selector,
pane_widths=ratios + [0],
)

# a first update
self._update_ipywidget(None)

self.unit_selector.observe(self._update_ipywidget, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)

def _update_ipywidget(self, change):
from matplotlib.lines import Line2D

unit_ids = self.controller["unit_ids"].value
unit_ids = self.unit_selector.value

unit_colors = self.data_plot["unit_colors"]
# matplotlib next_data_plot dict update at each call
Expand All @@ -198,6 +192,7 @@ def _update_ipywidget(self, change):
self.plot_matplotlib(self.data_plot, **backend_kwargs)

if len(unit_ids) > 0:
# TODO later make option to control legend or not
for l in self.figure.legends:
l.remove()
handles = [
Expand Down
34 changes: 11 additions & 23 deletions src/spikeinterface/widgets/spike_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
import ipywidgets.widgets as widgets
from IPython.display import display
from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller
from .utils_ipywidgets import check_ipywidget_backend, UnitSelector

check_ipywidget_backend()

Expand All @@ -210,48 +210,36 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
plt.show()

data_plot["unit_ids"] = data_plot["unit_ids"][:1]

unit_widget, unit_controller = make_unit_controller(
data_plot["unit_ids"],
list(data_plot["unit_colors"].keys()),
ratios[0] * width_cm,
height_cm,
)

self.controller = unit_controller

for w in self.controller.values():
w.observe(self._update_ipywidget)
self.unit_selector = UnitSelector(data_plot["unit_ids"])
self.unit_selector.value = list(data_plot["unit_ids"])[:1]

self.widget = widgets.AppLayout(
center=fig.canvas,
left_sidebar=unit_widget,
left_sidebar=self.unit_selector,
pane_widths=ratios + [0],
)

# a first update
self._update_ipywidget(None)
self._update_ipywidget()

self.unit_selector.observe(self._update_ipywidget, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)

def _update_ipywidget(self, change):
def _update_ipywidget(self, change=None):
self.ax.clear()

unit_ids = self.controller["unit_ids"].value

# matplotlib next_data_plot dict update at each call
data_plot = self.next_data_plot
data_plot["unit_ids"] = unit_ids
data_plot["unit_ids"] = self.unit_selector.value
data_plot["plot_all_units"] = True
# TODO add an option checkbox for legend
data_plot["plot_legend"] = True
data_plot["hide_axis"] = True

backend_kwargs = {}
backend_kwargs["ax"] = self.ax
backend_kwargs = dict(ax=self.ax)

# self.mpl_plotter.do_plot(data_plot, **backend_kwargs)
self.plot_matplotlib(data_plot, **backend_kwargs)
fig = self.ax.get_figure()
fig.canvas.draw()
Expand Down
Loading