Skip to content

Commit

Permalink
Merge pull request #2035 from samuelgarcia/improve_plot_traces_ipywid…
Browse files Browse the repository at this point in the history
…gets

Improve plot traces ipywidgets
  • Loading branch information
samuelgarcia authored Sep 29, 2023
2 parents c8be1a0 + 2c015f7 commit 89affa7
Show file tree
Hide file tree
Showing 9 changed files with 605 additions and 383 deletions.
78 changes: 43 additions & 35 deletions src/spikeinterface/widgets/amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,16 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
else:
bins = dp.bins
ax_hist = self.axes.flatten()[1]
ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8)
# this is super slow, using plot and np.histogram is really much faster (and nicer!)
# ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8)
count, bins = np.histogram(amps, bins=bins)
ax_hist.plot(count, bins[:-1], color=dp.unit_colors[unit_id], alpha=0.8)

if dp.plot_histograms:
ax_hist = self.axes.flatten()[1]
ax_hist.set_ylim(scatter_ax.get_ylim())
ax_hist.axis("off")
self.figure.tight_layout()
# self.figure.tight_layout()

if dp.plot_legend:
if hasattr(self, "legend") and self.legend is not None:
Expand All @@ -171,9 +174,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
import ipywidgets.widgets as widgets

# import ipywidgets.widgets as widgets
import ipywidgets.widgets as W
from IPython.display import display
from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller
from .utils_ipywidgets import check_ipywidget_backend, UnitSelector

check_ipywidget_backend()

Expand All @@ -188,60 +193,63 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
ratios = [0.15, 0.85]

with plt.ioff():
output = widgets.Output()
output = W.Output()
with output:
self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
plt.show()

data_plot["unit_ids"] = data_plot["unit_ids"][:1]
unit_widget, unit_controller = make_unit_controller(
data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm
)
self.unit_selector = UnitSelector(we.unit_ids)
self.unit_selector.value = list(we.unit_ids)[:1]

plot_histograms = widgets.Checkbox(
self.checkbox_histograms = W.Checkbox(
value=data_plot["plot_histograms"],
description="plot histograms",
disabled=False,
description="hist",
)

footer = plot_histograms

self.controller = {"plot_histograms": plot_histograms}
self.controller.update(unit_controller)

for w in self.controller.values():
w.observe(self._update_ipywidget)
left_sidebar = W.VBox(
children=[
self.unit_selector,
self.checkbox_histograms,
],
layout=W.Layout(align_items="center", width="4cm", height="100%"),
)

self.widget = widgets.AppLayout(
self.widget = W.AppLayout(
center=self.figure.canvas,
left_sidebar=unit_widget,
left_sidebar=left_sidebar,
pane_widths=ratios + [0],
footer=footer,
)

# a first update
self._update_ipywidget(None)
self._full_update_plot()

self.unit_selector.observe(self._update_plot, names="value", type="change")
self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)

def _update_ipywidget(self, change):
def _full_update_plot(self, change=None):
self.figure.clear()
data_plot = self.next_data_plot
data_plot["unit_ids"] = self.unit_selector.value
data_plot["plot_histograms"] = self.checkbox_histograms.value
data_plot["plot_legend"] = False

unit_ids = self.controller["unit_ids"].value
plot_histograms = self.controller["plot_histograms"].value
backend_kwargs = dict(figure=self.figure, axes=None, ax=None)
self.plot_matplotlib(data_plot, **backend_kwargs)
self._update_plot()

# matplotlib next_data_plot dict update at each call
data_plot = self.next_data_plot
data_plot["unit_ids"] = unit_ids
data_plot["plot_histograms"] = plot_histograms
def _update_plot(self, change=None):
for ax in self.axes.flatten():
ax.clear()

backend_kwargs = {}
# backend_kwargs["figure"] = self.fig
backend_kwargs["figure"] = self.figure
backend_kwargs["axes"] = None
backend_kwargs["ax"] = None
data_plot = self.next_data_plot
data_plot["unit_ids"] = self.unit_selector.value
data_plot["plot_histograms"] = self.checkbox_histograms.value
data_plot["plot_legend"] = False

backend_kwargs = dict(figure=None, axes=self.axes, ax=None)
self.plot_matplotlib(data_plot, **backend_kwargs)

self.figure.canvas.draw()
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/widgets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ def set_default_plotter_backend(backend):
"width_cm": "Width of the figure in cm (default 10)",
"height_cm": "Height of the figure in cm (default 6)",
"display": "If True, widgets are immediately displayed",
# "controllers": ""
},
"ephyviewer": {},
}

default_backend_kwargs = {
"matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None},
"sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None},
"ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True},
"ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True, "controllers": None},
"ephyviewer": {},
}

Expand Down
21 changes: 8 additions & 13 deletions src/spikeinterface/widgets/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
import ipywidgets.widgets as widgets
from IPython.display import display
from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller
from .utils_ipywidgets import check_ipywidget_backend, UnitSelector

check_ipywidget_backend()

Expand All @@ -147,34 +147,28 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
with output:
self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
plt.show()
if data_plot["unit_ids"] is None:
data_plot["unit_ids"] = []

unit_widget, unit_controller = make_unit_controller(
data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm
)

self.controller = unit_controller

for w in self.controller.values():
w.observe(self._update_ipywidget)
self.unit_selector = UnitSelector(data_plot["sorting"].unit_ids)
self.unit_selector.value = []

self.widget = widgets.AppLayout(
center=self.figure.canvas,
left_sidebar=unit_widget,
left_sidebar=self.unit_selector,
pane_widths=ratios + [0],
)

# a first update
self._update_ipywidget(None)

self.unit_selector.observe(self._update_ipywidget, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)

def _update_ipywidget(self, change):
from matplotlib.lines import Line2D

unit_ids = self.controller["unit_ids"].value
unit_ids = self.unit_selector.value

unit_colors = self.data_plot["unit_colors"]
# matplotlib next_data_plot dict update at each call
Expand All @@ -198,6 +192,7 @@ def _update_ipywidget(self, change):
self.plot_matplotlib(self.data_plot, **backend_kwargs)

if len(unit_ids) > 0:
# TODO later make option to control legend or not
for l in self.figure.legends:
l.remove()
handles = [
Expand Down
34 changes: 11 additions & 23 deletions src/spikeinterface/widgets/spike_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
import ipywidgets.widgets as widgets
from IPython.display import display
from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller
from .utils_ipywidgets import check_ipywidget_backend, UnitSelector

check_ipywidget_backend()

Expand All @@ -210,48 +210,36 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
plt.show()

data_plot["unit_ids"] = data_plot["unit_ids"][:1]

unit_widget, unit_controller = make_unit_controller(
data_plot["unit_ids"],
list(data_plot["unit_colors"].keys()),
ratios[0] * width_cm,
height_cm,
)

self.controller = unit_controller

for w in self.controller.values():
w.observe(self._update_ipywidget)
self.unit_selector = UnitSelector(data_plot["unit_ids"])
self.unit_selector.value = list(data_plot["unit_ids"])[:1]

self.widget = widgets.AppLayout(
center=fig.canvas,
left_sidebar=unit_widget,
left_sidebar=self.unit_selector,
pane_widths=ratios + [0],
)

# a first update
self._update_ipywidget(None)
self._update_ipywidget()

self.unit_selector.observe(self._update_ipywidget, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)

def _update_ipywidget(self, change):
def _update_ipywidget(self, change=None):
self.ax.clear()

unit_ids = self.controller["unit_ids"].value

# matplotlib next_data_plot dict update at each call
data_plot = self.next_data_plot
data_plot["unit_ids"] = unit_ids
data_plot["unit_ids"] = self.unit_selector.value
data_plot["plot_all_units"] = True
# TODO add an option checkbox for legend
data_plot["plot_legend"] = True
data_plot["hide_axis"] = True

backend_kwargs = {}
backend_kwargs["ax"] = self.ax
backend_kwargs = dict(ax=self.ax)

# self.mpl_plotter.do_plot(data_plot, **backend_kwargs)
self.plot_matplotlib(data_plot, **backend_kwargs)
fig = self.ax.get_figure()
fig.canvas.draw()
Expand Down
Loading

0 comments on commit 89affa7

Please sign in to comment.