Skip to content

Commit

Permalink
Merge pull request #3003 from alejoe91/plot-potential-merges
Browse files Browse the repository at this point in the history
Plot potential merges
  • Loading branch information
alejoe91 authored Jun 14, 2024
2 parents 17ee785 + bdc7fe4 commit bd89c99
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 29 deletions.
8 changes: 4 additions & 4 deletions src/spikeinterface/widgets/amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.next_data_plot = data_plot.copy()

cm = 1 / 2.54
we = data_plot["sorting_analyzer"]
analyzer = data_plot["sorting_analyzer"]

width_cm = backend_kwargs["width_cm"]
height_cm = backend_kwargs["height_cm"]
Expand All @@ -202,8 +202,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
plt.show()

self.unit_selector = UnitSelector(we.unit_ids)
self.unit_selector.value = list(we.unit_ids)[:1]
self.unit_selector = UnitSelector(analyzer.unit_ids)
self.unit_selector.value = list(analyzer.unit_ids)[:1]

self.checkbox_histograms = W.Checkbox(
value=data_plot["plot_histograms"],
Expand All @@ -215,7 +215,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.unit_selector,
self.checkbox_histograms,
],
layout=W.Layout(align_items="center", width="4cm", height="100%"),
layout=W.Layout(align_items="center", width="100%", height="100%"),
)

self.widget = W.AppLayout(
Expand Down
248 changes: 248 additions & 0 deletions src/spikeinterface/widgets/potential_merges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
from __future__ import annotations

import numpy as np
from warnings import warn

from .base import BaseWidget, default_backend_kwargs

from .amplitudes import AmplitudesWidget
from .crosscorrelograms import CrossCorrelogramsWidget
from .unit_templates import UnitTemplatesWidget

from .utils import get_some_colors

from ..core.sortinganalyzer import SortingAnalyzer


class PotentialMergesWidget(BaseWidget):
"""
Plots potential merges
Parameters
----------
sorting_analyzer : SortingAnalyzer
The input sorting analyzer
potential_merges : list of lists or tuples
List of potential merges (see `spikeinterface.curation.get_potential_auto_merges`)
segment_index : int
The segment index to display
max_spike_samples : int or None, default: None
The maximum number of spikes to display per unit
"""

def __init__(
self,
sorting_analyzer: SortingAnalyzer,
potential_merges: list,
unit_colors: list = None,
segment_index: int = 0,
max_spikes_per_unit: int = 100,
backend=None,
**backend_kwargs,
):
sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

self.check_extensions(sorting_analyzer, ["templates", "spike_amplitudes", "correlograms"])

unique_merge_units = np.unique([u for merge in potential_merges for u in merge])
if unit_colors is None:
unit_colors = get_some_colors(sorting_analyzer.unit_ids)

plot_data = dict(
sorting_analyzer=sorting_analyzer,
potential_merges=potential_merges,
unit_colors=unit_colors,
segment_index=segment_index,
max_spikes_per_unit=max_spikes_per_unit,
unique_merge_units=unique_merge_units,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_ipywidgets(self, data_plot, **backend_kwargs):
from math import lcm
import matplotlib.pyplot as plt

# import ipywidgets.widgets as widgets
import ipywidgets.widgets as W
from IPython.display import display
from .utils_ipywidgets import check_ipywidget_backend, ScaleWidget, WidenNarrowWidget

check_ipywidget_backend()

self.next_data_plot = data_plot.copy()

cm = 1 / 2.54
analyzer = data_plot["sorting_analyzer"]

width_cm = backend_kwargs["width_cm"]
height_cm = backend_kwargs["height_cm"] * 3

ratios = [0.2, 0.8]

with plt.ioff():
output = W.Output()
with output:
self.figure = plt.figure(
figsize=((ratios[1] * width_cm) * cm, height_cm * cm),
constrained_layout=True,
)
plt.show()
# find max number of merges:
self.gs = None
self.axes_amplitudes = None
self.ax_templates = None
self.ax_probe = None
self.axes_cc = None

# Instantiate sub-widgets
self.w_amplitudes = AmplitudesWidget(
analyzer,
unit_colors=data_plot["unit_colors"],
unit_ids=data_plot["unique_merge_units"],
plot_histograms=True,
plot_legend=False,
immediate_plot=False,
)
self.w_templates = UnitTemplatesWidget(
analyzer,
unit_ids=data_plot["unique_merge_units"],
unit_colors=data_plot["unit_colors"],
plot_legend=False,
immediate_plot=False,
)
self.w_crosscorrelograms = CrossCorrelogramsWidget(
analyzer,
unit_ids=data_plot["unique_merge_units"],
min_similarity_for_correlograms=0,
unit_colors=data_plot["unit_colors"],
immediate_plot=False,
)

options = ["-".join([str(u) for u in m]) for m in data_plot["potential_merges"]]
value = options[0]
self.unit_selector_label = W.Label(value="Potential merges:")
self.unit_selector = W.Dropdown(options=options, value=value, layout=W.Layout(width="80%"))
self.previous_num_merges = len(data_plot["potential_merges"][0])
self.scaler = ScaleWidget(value=1.0, layout=W.Layout(width="80%"))
self.widen_narrow = WidenNarrowWidget(value=1.0, layout=W.Layout(width="80%"))

left_sidebar = W.VBox(
[self.unit_selector_label, self.unit_selector, self.scaler, self.widen_narrow],
layout=W.Layout(width="100%"),
)

self.widget = W.AppLayout(
center=self.figure.canvas,
left_sidebar=left_sidebar,
pane_widths=ratios + [0],
)

if len(np.unique([len(m) for m in self.data_plot["potential_merges"]])) == 1:
# in this case we multiply the number of columns by 3 to have 2/3 of the space for the templates
ncols = 3 * len(self.data_plot["potential_merges"][0])
else:
ncols = lcm(*[len(m) for m in self.data_plot["potential_merges"]])
right_axes = int(ncols * 2 / 3)
self.ncols = ncols
self.right_axes = right_axes

# a first update
self._update_plot(None)

self.unit_selector.observe(self._update_plot, names="value", type="change")
self.scaler.observe(self._update_plot, names="value", type="change")
self.widen_narrow.observe(self._update_plot, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)

def _update_gs(self, merge_units):
import matplotlib.gridspec as gridspec

# we create a vertical grid with 1 row between the 3 first plots
n_units = len(merge_units)
ncols = self.ncols
right_axes = self.right_axes
unit_len_in_gs = self.ncols // n_units
nrows = ncols * 3 + 2

if self.gs is not None and self.previous_num_merges == len(merge_units):
self.ax_templates.clear()
self.ax_probe.clear()
for ax in self.axes_amplitudes:
ax.clear()
for ax in self.axes_cc.flatten():
ax.clear()
else:
self.figure.clear()
self.gs = gridspec.GridSpec(nrows, ncols, figure=self.figure)
self.ax_templates = self.figure.add_subplot(self.gs[:ncols, :right_axes])
self.ax_probe = self.figure.add_subplot(self.gs[:ncols, right_axes:])
row_offset = ncols + 1
ax_amplitudes_ts = self.figure.add_subplot(self.gs[row_offset : row_offset + ncols, :right_axes])
ax_amplitudes_hist = self.figure.add_subplot(self.gs[row_offset : row_offset + ncols, right_axes:])
self.axes_amplitudes = [ax_amplitudes_ts, ax_amplitudes_hist]
row_offset += ncols + 1
self.axes_cc = []
for i in range(0, n_units):
for j in range(0, n_units):
self.axes_cc.append(
self.figure.add_subplot(
self.gs[
row_offset + (unit_len_in_gs) * i : row_offset + (unit_len_in_gs) * (i + 1),
j * unit_len_in_gs : (j + 1) * unit_len_in_gs,
]
)
)
self.axes_cc = np.array(self.axes_cc).reshape((n_units, n_units))
self.previous_num_merges = len(merge_units)

def _update_plot(self, change=None):

merge_units = self.unit_selector.value
sorting_analyzer = self.data_plot["sorting_analyzer"]
channel_locations = sorting_analyzer.get_channel_locations()
unit_ids = sorting_analyzer.unit_ids

# unroll the merges
unit_ids_str = [str(u) for u in unit_ids]
plot_unit_ids = []
for m in merge_units.split("-"):
plot_unit_ids.append(unit_ids[unit_ids_str.index(m)])
self._update_gs(plot_unit_ids)

backend_kwargs_mpl = default_backend_kwargs["matplotlib"].copy()
backend_kwargs_mpl.pop("axes")
backend_kwargs_mpl.pop("ax")

amplitude_data_plot = self.w_amplitudes.data_plot.copy()
amplitude_data_plot["unit_ids"] = plot_unit_ids
self.w_amplitudes.plot_matplotlib(amplitude_data_plot, ax=None, axes=self.axes_amplitudes, **backend_kwargs_mpl)

unit_template_data_plot = self.w_templates.data_plot.copy()
unit_template_data_plot["unit_ids"] = plot_unit_ids
unit_template_data_plot["same_axis"] = True
unit_template_data_plot["set_title"] = False
unit_template_data_plot["scale"] = self.scaler.value
unit_template_data_plot["widen_narrow_scale"] = self.widen_narrow.value
# update templates and shading
templates_ext = sorting_analyzer.get_extension("templates")
unit_template_data_plot["templates"] = templates_ext.get_templates(unit_ids=plot_unit_ids, operator="average")
unit_template_data_plot["templates_shading"] = self.w_templates._get_template_shadings(
plot_unit_ids, self.w_templates.data_plot["templates_percentile_shading"]
)
self.w_templates.plot_matplotlib(unit_template_data_plot, ax=self.ax_templates, axes=None, **backend_kwargs_mpl)
self.ax_templates.axis("off")
self.w_templates._plot_probe(self.ax_probe, channel_locations, plot_unit_ids)
crosscorrelograms_data_plot = self.w_crosscorrelograms.data_plot.copy()
crosscorrelograms_data_plot["unit_ids"] = plot_unit_ids
merge_unit_indices = np.flatnonzero(np.isin(self.unique_merge_units, plot_unit_ids))
updated_correlograms = crosscorrelograms_data_plot["correlograms"]
updated_correlograms = updated_correlograms[merge_unit_indices][:, merge_unit_indices]
crosscorrelograms_data_plot["correlograms"] = updated_correlograms
self.w_crosscorrelograms.plot_matplotlib(
crosscorrelograms_data_plot, axes=self.axes_cc, ax=None, **backend_kwargs_mpl
)
self.figure.canvas.draw()
self.figure.canvas.flush_events()
43 changes: 25 additions & 18 deletions src/spikeinterface/widgets/unit_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
else:
if dp.same_axis:
backend_kwargs["num_axes"] = 1
backend_kwargs["ncols"] = None
backend_kwargs["ncols"] = 1
else:
backend_kwargs["num_axes"] = len(dp.unit_ids)
backend_kwargs["ncols"] = min(dp.ncols, len(dp.unit_ids))
Expand Down Expand Up @@ -487,11 +487,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):

# a first update
self._update_plot(None)

self.unit_selector.observe(self._update_plot, names="value", type="change")
self.scaler.observe(self._update_plot, names="value", type="change")
self.widen_narrow.observe(self._update_plot, names="value", type="change")
for w in (
self.unit_selector,
self.scaler,
self.widen_narrow,
self.same_axis_button,
self.plot_templates_button,
self.template_shading_button,
Expand Down Expand Up @@ -592,30 +591,38 @@ def _update_plot(self, change):
ax.axis("off")

# update probe plot
self.ax_probe.plot(
self._plot_probe(
self.ax_probe,
channel_locations,
unit_ids,
)
fig_probe = self.ax_probe.get_figure()

self.fig_wf.canvas.draw()
self.fig_wf.canvas.flush_events()
fig_probe.canvas.draw()
fig_probe.canvas.flush_events()

def _plot_probe(self, ax, channel_locations, unit_ids):
# update probe plot
ax.plot(
channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5
)
self.ax_probe.axis("off")
self.ax_probe.axis("equal")
ax.axis("off")
ax.axis("equal")

# TODO this could be done with probeinterface plotting plotting tools!!
for unit in unit_ids:
channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit]
self.ax_probe.plot(
channel_inds = self.data_plot["sparsity"].unit_id_to_channel_indices[unit]
ax.plot(
channel_locations[channel_inds, 0],
channel_locations[channel_inds, 1],
ls="",
marker="o",
markersize=3,
color=self.next_data_plot["unit_colors"][unit],
color=self.data_plot["unit_colors"][unit],
)
self.ax_probe.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10)
fig_probe = self.ax_probe.get_figure()

self.fig_wf.canvas.draw()
self.fig_wf.canvas.flush_events()
fig_probe.canvas.draw()
fig_probe.canvas.flush_events()
ax.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10)


def get_waveforms_scales(templates, channel_locations, nbefore, x_offset_units=False, widen_narrow_scale=1.0):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/utils_ipywidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def __init__(self, unit_ids, **kwargs):
options=self.unit_ids,
value=self.unit_ids,
disabled=False,
layout=W.Layout(height="100%", width="80%", align="center"),
layout=W.Layout(height="100%", width="3cm", align="center"),
)

super(W.VBox, self).__init__(children=[label, self.selector], **kwargs)
Expand Down
8 changes: 2 additions & 6 deletions src/spikeinterface/widgets/utils_matplotlib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

Expand All @@ -12,13 +11,10 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None,
if figure is not None:
assert ax is None and axes is None, "figure/ax/axes : only one of then can be not None"
if num_axes is None:
if "ipympl" not in matplotlib.get_backend():
ax = figure.add_subplot(111)
else:
ax = figure.add_subplot(111)
ax = figure.add_subplot(111)
axes = np.array([[ax]])
else:
assert ncols is not None
assert ncols is not None, "ncols must be provided when num_axes is provided"
axes = []
nrows = int(np.ceil(num_axes / ncols))
axes = np.full((nrows, ncols), fill_value=None, dtype=object)
Expand Down
Loading

0 comments on commit bd89c99

Please sign in to comment.