Skip to content

Commit

Permalink
Merge pull request #3242 from florian6973/florian6973_unitsummarywidget
Browse files Browse the repository at this point in the history
Add subwidget parameters for UnitSummaryWidget
  • Loading branch information
alejoe91 authored Sep 12, 2024
2 parents 3d725a4 + 01485d9 commit 48b2131
Showing 1 changed file with 33 additions and 3 deletions.
36 changes: 33 additions & 3 deletions src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections import defaultdict

import numpy as np

Expand All @@ -17,7 +18,7 @@ class UnitSummaryWidget(BaseWidget):
"""
Plot a unit summary.
If amplitudes are alreday computed they are displayed.
If amplitudes are alreday computed, they are displayed.
Parameters
----------
Expand All @@ -30,6 +31,14 @@ class UnitSummaryWidget(BaseWidget):
sparsity : ChannelSparsity or None, default: None
Optional ChannelSparsity to apply.
If SortingAnalyzer is already sparse, the argument is ignored
subwidget_kwargs : dict or None, default: None
Parameters for the subwidgets in a nested dictionary
unit_locations : UnitLocationsWidget (see UnitLocationsWidget for details)
unit_waveforms : UnitWaveformsWidget (see UnitWaveformsWidget for details)
unit_waveform_density_map : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details)
autocorrelograms : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details)
amplitudes : AmplitudesWidget (see AmplitudesWidget for details)
Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary.
"""

# possible_backends = {}
Expand All @@ -40,21 +49,29 @@ def __init__(
unit_id,
unit_colors=None,
sparsity=None,
radius_um=100,
subwidget_kwargs=None,
backend=None,
**backend_kwargs,
):

sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

if unit_colors is None:
unit_colors = get_unit_colors(sorting_analyzer)

if subwidget_kwargs is None:
subwidget_kwargs = dict()
for kwargs in subwidget_kwargs.values():
if "unit_colors" in kwargs:
raise ValueError(
"unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary"
)

plot_data = dict(
sorting_analyzer=sorting_analyzer,
unit_id=unit_id,
unit_colors=unit_colors,
sparsity=sparsity,
subwidget_kwargs=subwidget_kwargs,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
Expand All @@ -70,6 +87,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
unit_colors = dp.unit_colors
sparsity = dp.sparsity

# defaultdict returns empty dict if key not found in subwidget_kwargs
subwidget_kwargs = defaultdict(lambda: dict(), dp.subwidget_kwargs)
unitlocationswidget_kwargs = subwidget_kwargs["unit_locations"]
unitwaveformswidget_kwargs = subwidget_kwargs["unit_waveforms"]
unitwaveformdensitymapwidget_kwargs = subwidget_kwargs["unit_waveform_density_map"]
autocorrelogramswidget_kwargs = subwidget_kwargs["autocorrelograms"]
amplitudeswidget_kwargs = subwidget_kwargs["amplitudes"]

# force the figure without axes
if "figsize" not in backend_kwargs:
backend_kwargs["figsize"] = (18, 7)
Expand Down Expand Up @@ -99,6 +124,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
plot_legend=False,
backend="matplotlib",
ax=ax1,
**unitlocationswidget_kwargs,
)

unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
Expand All @@ -121,6 +147,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
sparsity=sparsity,
backend="matplotlib",
ax=ax2,
**unitwaveformswidget_kwargs,
)

ax2.set_title(None)
Expand All @@ -134,6 +161,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
same_axis=False,
backend="matplotlib",
ax=ax3,
**unitwaveformdensitymapwidget_kwargs,
)
ax3.set_ylabel(None)

Expand All @@ -145,6 +173,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
unit_colors=unit_colors,
backend="matplotlib",
ax=ax4,
**autocorrelogramswidget_kwargs,
)

ax4.set_title(None)
Expand All @@ -162,6 +191,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
plot_histograms=True,
backend="matplotlib",
axes=axes,
**amplitudeswidget_kwargs,
)

fig.suptitle(f"unit_id: {dp.unit_id}")

0 comments on commit 48b2131

Please sign in to comment.