diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index f1f85b7dc3..fd7dafd5d6 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -18,26 +18,27 @@ class UnitSummaryWidget(BaseWidget): """ Plot a unit summary. - If amplitudes are alreday computed they are displayed. + If amplitudes are alreday computed, they are displayed. Parameters ---------- - sorting_analyzer : SortingAnalyzer + sorting_analyzer: SortingAnalyzer The SortingAnalyzer object - unit_id : int or str + unit_id: int or str The unit id to plot the summary of - unit_colors : dict or None, default: None + unit_colors: dict or None, default: None If given, a dictionary with unit ids as keys and colors as values, - sparsity : ChannelSparsity or None, default: None + sparsity: ChannelSparsity or None, default: None Optional ChannelSparsity to apply. If SortingAnalyzer is already sparse, the argument is ignored - widget_params : dict or None, default: None + subwidget_kwargs: dict or None, default: None Parameters for the subwidgets in a nested dictionary unit_locations: UnitLocationsWidget (see UnitLocationsWidget for details) unit_waveforms: UnitWaveformsWidget (see UnitWaveformsWidget for details) unit_waveform_density_map: UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) autocorrelograms: AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) amplitudes: AmplitudesWidget (see AmplitudesWidget for details) + Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary. """ # possible_backends = {} @@ -48,7 +49,7 @@ def __init__( unit_id, unit_colors=None, sparsity=None, - widget_params=None, + subwidget_kwargs=None, backend=None, **backend_kwargs, ): @@ -57,15 +58,18 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(sorting_analyzer) - if widget_params is None: - widget_params = dict() + if subwidget_kwargs is None: + subwidget_kwargs = dict() + for kwargs in subwidget_kwargs.values(): + if "unit_colors" in kwargs: + raise ValueError("unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary") plot_data = dict( sorting_analyzer=sorting_analyzer, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, - widget_params=widget_params, + subwidget_kwargs=subwidget_kwargs, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -81,12 +85,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors = dp.unit_colors sparsity = dp.sparsity - widget_params = defaultdict(lambda: dict(), dp.widget_params) - unitlocationswidget_params = widget_params["unit_locations"] - unitwaveformswidget_params = widget_params["unit_waveforms"] - unitwaveformdensitymapwidget_params = widget_params["unit_waveform_density_map"] - autocorrelogramswidget_params = widget_params["autocorrelograms"] - amplitudeswidget_params = widget_params["amplitudes"] + subwidget_kwargs = defaultdict(lambda: dict(), dp.subwidget_kwargs) + unitlocationswidget_kwargs = subwidget_kwargs["unit_locations"] + unitwaveformswidget_kwargs = subwidget_kwargs["unit_waveforms"] + unitwaveformdensitymapwidget_kwargs = subwidget_kwargs["unit_waveform_density_map"] + autocorrelogramswidget_kwargs = subwidget_kwargs["autocorrelograms"] + amplitudeswidget_kwargs = subwidget_kwargs["amplitudes"] # force the figure without axes if "figsize" not in backend_kwargs: @@ -117,7 +121,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, - **unitlocationswidget_params, + **unitlocationswidget_kwargs, ) unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -140,7 +144,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, - **unitwaveformswidget_params, + **unitwaveformswidget_kwargs, ) ax2.set_title(None) @@ -154,7 +158,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, - **unitwaveformdensitymapwidget_params, + **unitwaveformdensitymapwidget_kwargs, ) ax3.set_ylabel(None) @@ -166,7 +170,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, - **autocorrelogramswidget_params, + **autocorrelogramswidget_kwargs, ) ax4.set_title(None) @@ -184,7 +188,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, - **amplitudeswidget_params, + **amplitudeswidget_kwargs, ) fig.suptitle(f"unit_id: {dp.unit_id}")