Skip to content

Commit

Permalink
Add widen/narrow button and scale bar to plot_unitwaveforms/templates
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed May 14, 2024
1 parent f034b39 commit e6d5cb3
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 24 deletions.
111 changes: 87 additions & 24 deletions src/spikeinterface/widgets/unit_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class UnitWaveformsWidget(BaseWidget):
displayed per waveform, (matplotlib backend)
scale : float, default: 1
Scale factor for the waveforms/templates (matplotlib backend)
widen_narrow_scale : float, default: 1
Scale factor for the x-axis of the waveforms/templates (matplotlib backend)
axis_equal : bool, default: False
Equal aspect ratio for x and y axis, to visualize the array geometry to scale
lw_waveforms : float, default: 1
Expand All @@ -65,6 +67,8 @@ class UnitWaveformsWidget(BaseWidget):
are used for the lower bounds, and the second half for the upper bounds.
Inner elements produce darker shadings. For sortingview backend only 2 or 4 elements are
supported.
scalebar : bool, default: False
Display a scale bar on the waveforms plot (matplotlib backend)
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
same_axis : bool, default: False
Expand All @@ -88,6 +92,7 @@ def __init__(
sparsity=None,
ncols=5,
scale=1,
widen_narrow_scale=1,
lw_waveforms=1,
lw_templates=2,
axis_equal=False,
Expand All @@ -97,6 +102,7 @@ def __init__(
same_axis=False,
shade_templates=True,
templates_percentile_shading=(1, 25, 75, 99),
scalebar=False,
x_offset_units=False,
alpha_waveforms=0.5,
alpha_templates=1,
Expand Down Expand Up @@ -168,10 +174,6 @@ def __init__(
templates_percentile_shading = None
templates_shading = self._get_template_shadings(unit_ids, templates_percentile_shading)

xvectors, y_scale, y_offset, delta_x = get_waveforms_scales(
templates, channel_locations, nbefore, x_offset_units
)

wfs_by_ids = {}
if plot_waveforms:
# this must be a sorting_analyzer
Expand Down Expand Up @@ -204,12 +206,14 @@ def __init__(
plot_data = dict(
sorting_analyzer_or_templates=sorting_analyzer_or_templates,
sampling_frequency=sorting_analyzer_or_templates.sampling_frequency,
nbefore=nbefore,
unit_ids=unit_ids,
channel_ids=channel_ids,
sparsity=sparsity,
unit_colors=unit_colors,
channel_locations=channel_locations,
scale=scale,
widen_narrow_scale=widen_narrow_scale,
templates=templates,
templates_shading=templates_shading,
do_shading=shade_templates,
Expand All @@ -220,19 +224,16 @@ def __init__(
unit_selected_waveforms=unit_selected_waveforms,
axis_equal=axis_equal,
max_spikes_per_unit=max_spikes_per_unit,
xvectors=xvectors,
y_scale=y_scale,
y_offset=y_offset,
wfs_by_ids=wfs_by_ids,
set_title=set_title,
same_axis=same_axis,
scalebar=scalebar,
templates_percentile_shading=templates_percentile_shading,
x_offset_units=x_offset_units,
lw_waveforms=lw_waveforms,
lw_templates=lw_templates,
alpha_waveforms=alpha_waveforms,
alpha_templates=alpha_templates,
delta_x=delta_x,
hide_unit_selector=hide_unit_selector,
plot_legend=plot_legend,
)
Expand All @@ -258,6 +259,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

xvectors, y_scale, y_offset, delta_x = get_waveforms_scales(
dp.templates, dp.channel_locations, dp.nbefore, dp.x_offset_units, dp.widen_narrow_scale
)

for i, unit_id in enumerate(dp.unit_ids):
if dp.same_axis:
ax = self.ax
Expand All @@ -266,7 +271,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
color = dp.unit_colors[unit_id]

chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id]
xvectors_flat = dp.xvectors[:, chan_inds].T.flatten()
xvectors_flat = xvectors[:, chan_inds].T.flatten()

# plot waveforms
if dp.plot_waveforms:
Expand All @@ -278,27 +283,46 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
random_idxs = np.random.permutation(len(wfs))[: dp.max_spikes_per_unit]
wfs = wfs[random_idxs]

wfs = wfs * dp.y_scale + dp.y_offset[None, :, chan_inds]
wfs = wfs * y_scale + y_offset[None, :, chan_inds]
wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T

if dp.x_offset_units:
# 0.7 is to match spacing in xvect
xvec = xvectors_flat + i * 0.7 * dp.delta_x
xvec = xvectors_flat + i * 0.7 * delta_x
else:
xvec = xvectors_flat

ax.plot(xvec, wfs_flat, lw=dp.lw_waveforms, alpha=dp.alpha_waveforms, color=color)

if not dp.plot_templates:
ax.get_lines()[-1].set_label(f"{unit_id}")
if not dp.plot_templates and dp.scalebar and not dp.same_axis:
# xscale
min_wfs = np.min(wfs_flat)
wfs_for_scale = dp.wfs_by_ids[unit_id] * y_scale
offset = 0.1 * (np.max(wfs_flat) - np.min(wfs_flat))
xargmin = np.nanargmin(xvec)
xscale_bar = [xvec[xargmin], xvec[xargmin + dp.nbefore]]
ax.plot(xscale_bar, [min_wfs - offset, min_wfs - offset], color="k")
nbefore_time = int(dp.nbefore / dp.sampling_frequency * 1000)
ax.text(
xscale_bar[0] + xscale_bar[1] // 3, min_wfs - 1.5 * offset, f"{nbefore_time} ms", fontsize=8
)

# yscale
length = int(np.ptp(wfs_flat) // 5)
length_uv = int(np.ptp(wfs_for_scale) // 5)
x_offset = xscale_bar[0] - np.ptp(xscale_bar) // 2
ax.plot([xscale_bar[0], xscale_bar[0]], [min_wfs - offset, min_wfs - offset + length], color="k")
ax.text(x_offset, min_wfs - offset + length // 3, f"{length_uv} $\mu$V", fontsize=8, rotation=90)

# plot template
if dp.plot_templates:
template = dp.templates[i, :, :][:, chan_inds] * dp.scale * dp.y_scale + dp.y_offset[:, chan_inds]
template = dp.templates[i, :, :][:, chan_inds] * dp.scale * y_scale + y_offset[:, chan_inds]

if dp.x_offset_units:
# 0.7 is to match spacing in xvect
xvec = xvectors_flat + i * 0.7 * dp.delta_x
xvec = xvectors_flat + i * 0.7 * delta_x
else:
xvec = xvectors_flat
# plot template shading if waveforms are not plotted
Expand All @@ -310,12 +334,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
shading_alphas = np.linspace(lightest_gray_alpha, darkest_gray_alpha, n_shadings)
for s in range(n_shadings):
lower_bound = (
dp.templates_shading[s][i, :, :][:, chan_inds] * dp.scale * dp.y_scale
+ dp.y_offset[:, chan_inds]
dp.templates_shading[s][i, :, :][:, chan_inds] * dp.scale * y_scale + y_offset[:, chan_inds]
)
upper_bound = (
dp.templates_shading[n_percentiles - 1 - s][i, :, :][:, chan_inds] * dp.scale * dp.y_scale
+ dp.y_offset[:, chan_inds]
dp.templates_shading[n_percentiles - 1 - s][i, :, :][:, chan_inds] * dp.scale * y_scale
+ y_offset[:, chan_inds]
)
ax.fill_between(
xvec,
Expand Down Expand Up @@ -345,6 +368,26 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
if dp.set_title:
ax.set_title(f"template {template_label}")

if not dp.plot_waveforms and dp.scalebar and not dp.same_axis:
# xscale
template_for_scale = dp.templates[i, :, :][:, chan_inds] * dp.scale
min_wfs = np.min(template)
offset = 0.1 * (np.max(template) - np.min(template))
xargmin = np.nanargmin(xvec)
xscale_bar = [xvec[xargmin], xvec[xargmin + dp.nbefore]]
ax.plot(xscale_bar, [min_wfs - offset, min_wfs - offset], color="k")
nbefore_time = int(dp.nbefore / dp.sampling_frequency * 1000)
ax.text(
xscale_bar[0] + xscale_bar[1] // 3, min_wfs - 1.5 * offset, f"{nbefore_time} ms", fontsize=8
)

# yscale
length = int(np.ptp(template) // 5)
length_uv = int(np.ptp(template_for_scale) // 5)
x_offset = xscale_bar[0] - np.ptp(xscale_bar) // 2
ax.plot([xscale_bar[0], xscale_bar[0]], [min_wfs - offset, min_wfs - offset + length], color="k")
ax.text(x_offset, min_wfs - offset + length // 3, f"{length_uv} $\mu$V", fontsize=8, rotation=90)

# plot channels
if dp.plot_channels:
# TODO enhance this
Expand All @@ -361,7 +404,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
import ipywidgets.widgets as widgets
from IPython.display import display
from .utils_ipywidgets import check_ipywidget_backend, UnitSelector, ScaleWidget
from .utils_ipywidgets import check_ipywidget_backend, UnitSelector, ScaleWidget, WidenNarrowWidget

check_ipywidget_backend()

Expand Down Expand Up @@ -393,6 +436,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.unit_selector = UnitSelector(data_plot["unit_ids"], layout=widgets.Layout(height="80%"))
self.unit_selector.value = list(data_plot["unit_ids"])[:1]
self.scaler = ScaleWidget(value=data_plot["scale"], layout=widgets.Layout(height="20%"))
self.widen_narrow = WidenNarrowWidget(value=1.0, layout=widgets.Layout(height="20%"))

self.same_axis_button = widgets.Checkbox(
value=False,
Expand All @@ -417,15 +461,21 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
description="hide axis",
disabled=False,
)

self.scalebar = widgets.Checkbox(
value=False,
description="scalebar",
disabled=False,
)
if self.sorting_analyzer is not None:
footer_list = [self.same_axis_button, self.template_shading_button, self.hide_axis_button]
footer_list = [self.same_axis_button, self.template_shading_button, self.hide_axis_button, self.scalebar]
else:
footer_list = [self.same_axis_button, self.hide_axis_button]
footer_list = [self.same_axis_button, self.hide_axis_button, self.scalebar]
if data_plot["plot_waveforms"]:
footer_list.append(self.plot_templates_button)

footer = widgets.HBox(footer_list)
left_sidebar = widgets.VBox([self.unit_selector, self.scaler])
left_sidebar = widgets.VBox([self.unit_selector, self.scaler, self.widen_narrow])

self.widget = widgets.AppLayout(
center=self.fig_wf.canvas,
Expand All @@ -440,7 +490,14 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):

self.unit_selector.observe(self._update_plot, names="value", type="change")
self.scaler.observe(self._update_plot, names="value", type="change")
for w in self.same_axis_button, self.plot_templates_button, self.template_shading_button, self.hide_axis_button:
self.widen_narrow.observe(self._update_plot, names="value", type="change")
for w in (
self.same_axis_button,
self.plot_templates_button,
self.template_shading_button,
self.hide_axis_button,
self.scalebar,
):
w.observe(self._update_plot, names="value", type="change")

if backend_kwargs["display"]:
Expand Down Expand Up @@ -502,6 +559,12 @@ def _update_plot(self, change):
data_plot["plot_templates"] = plot_templates
data_plot["do_shading"] = do_shading
data_plot["scale"] = self.scaler.value
data_plot["widen_narrow_scale"] = self.widen_narrow.value

if same_axis:
self.scalebar.value = False
data_plot["scalebar"] = self.scalebar.value

if data_plot["plot_waveforms"]:
wf_ext = self.sorting_analyzer.get_extension("waveforms")
data_plot["wfs_by_ids"] = {
Expand Down Expand Up @@ -554,7 +617,7 @@ def _update_plot(self, change):
fig_probe.canvas.flush_events()


def get_waveforms_scales(templates, channel_locations, nbefore, x_offset_units=False):
def get_waveforms_scales(templates, channel_locations, nbefore, x_offset_units=False, widen_narrow_scale=1.0):
"""
Return scales and x_vector for templates plotting
"""
Expand Down Expand Up @@ -582,7 +645,7 @@ def get_waveforms_scales(templates, channel_locations, nbefore, x_offset_units=F

nsamples = templates.shape[1]

xvect = delta_x * (np.arange(nsamples) - nbefore) / nsamples * 0.7
xvect = (delta_x * widen_narrow_scale) * (np.arange(nsamples) - nbefore) / nsamples * 0.7

if x_offset_units:
ch_locs = channel_locations
Expand Down
55 changes: 55 additions & 0 deletions src/spikeinterface/widgets/utils_ipywidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,61 @@ def value_changed(self, change=None):
self.update_label()


class WidenNarrowWidget(W.VBox):
value = traitlets.Float()

def __init__(self, value=1.0, factor=1.2, **kwargs):
assert factor > 1.0
self.factor = factor

self.scale_label = W.Label("Widen/Narrow", layout=W.Layout(width="95%", justify_content="center"))

self.right_selector = W.Button(
description="",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Increase horizontal scale",
icon="arrow-right",
# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"),
layout=W.Layout(width="60%", align_self="center"),
)

self.left_selector = W.Button(
description="",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Decrease horizontal scale",
icon="arrow-left",
# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"),
layout=W.Layout(width="60%", align_self="center"),
)

self.right_selector.on_click(self.left_clicked)
self.left_selector.on_click(self.right_clicked)

self.value = value
super(W.VBox, self).__init__(
children=[self.scale_label, W.HBox([self.left_selector, self.right_selector])],
# layout=W.Layout(align_items="center", width="100%", height="100%"),
**kwargs,
)

# self.update_label()
# self.observe(self.value_changed, names=["value"], type="change")

def update_label(self):
self.scale_label.value = f"Scale: {self.value:0.2f}"

def left_clicked(self, change=None):
self.value = self.value / self.factor

def right_clicked(self, change=None):
self.value = self.value * self.factor

def value_changed(self, change=None):
self.update_label()


class UnitSelector(W.VBox):
value = traitlets.List()

Expand Down

0 comments on commit e6d5cb3

Please sign in to comment.