diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py
index 012b1ac07c..a09304dc86 100644
--- a/src/spikeinterface/widgets/tests/test_widgets.py
+++ b/src/spikeinterface/widgets/tests/test_widgets.py
@@ -190,8 +190,16 @@ def test_plot_unit_waveforms(self):
                     backend=backend,
                     **self.backend_kwargs[backend],
                 )
-                # test "larger" sparsity
-                with self.assertRaises(AssertionError):
+                # channel ids
+                sw.plot_unit_waveforms(
+                    self.sorting_analyzer_sparse,
+                    channel_ids=self.sorting_analyzer_sparse.channel_ids[::3],
+                    unit_ids=unit_ids,
+                    backend=backend,
+                    **self.backend_kwargs[backend],
+                )
+                # test warning with "larger" sparsity
+                with self.assertWarns(UserWarning):
                     sw.plot_unit_waveforms(
                         self.sorting_analyzer_sparse,
                         sparsity=self.sparsity_large,
@@ -205,10 +213,10 @@ def test_plot_unit_templates(self):
         for backend in possible_backends:
             if backend not in self.skip_backends:
                 print(f"Testing backend {backend}")
-                print("Dense")
+                # dense
                 sw.plot_unit_templates(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend])
                 unit_ids = self.sorting.unit_ids[:6]
-                print("Dense + radius")
+                # dense + radius
                 sw.plot_unit_templates(
                     self.sorting_analyzer_dense,
                     sparsity=self.sparsity_radius,
@@ -216,7 +224,7 @@
                     backend=backend,
                     **self.backend_kwargs[backend],
                 )
-                print("Dense + best")
+                # dense + best
                 sw.plot_unit_templates(
                     self.sorting_analyzer_dense,
                     sparsity=self.sparsity_best,
@@ -225,7 +233,6 @@
                     **self.backend_kwargs[backend],
                 )
                 # test different shadings
-                print("Sparse")
                 sw.plot_unit_templates(
                     self.sorting_analyzer_sparse,
                     unit_ids=unit_ids,
@@ -233,7 +240,6 @@
                     backend=backend,
                     **self.backend_kwargs[backend],
                 )
-                print("Sparse2")
                 sw.plot_unit_templates(
                     self.sorting_analyzer_sparse,
                     unit_ids=unit_ids,
@@ -242,8 +248,6 @@
                     backend=backend,
                     **self.backend_kwargs[backend],
                 )
-                # test different shadings
-                print("Sparse3")
                 sw.plot_unit_templates(
                     self.sorting_analyzer_sparse,
                     unit_ids=unit_ids,
@@ -252,7 +256,6 @@
                     shade_templates=False,
                     **self.backend_kwargs[backend],
                 )
-                print("Sparse4")
                 sw.plot_unit_templates(
                     self.sorting_analyzer_sparse,
                     unit_ids=unit_ids,
@@ -260,7 +263,7 @@
                     backend=backend,
                     **self.backend_kwargs[backend],
                 )
-                print("Extra sparsity")
+                # extra sparsity
                 sw.plot_unit_templates(
                     self.sorting_analyzer_sparse,
                     sparsity=self.sparsity_strict,
@@ -269,8 +272,18 @@
                     backend=backend,
                     **self.backend_kwargs[backend],
                 )
+                # channel ids
+                sw.plot_unit_templates(
+                    self.sorting_analyzer_sparse,
+                    channel_ids=self.sorting_analyzer_sparse.channel_ids[::3],
+                    unit_ids=unit_ids,
+                    templates_percentile_shading=[1, 10, 90, 99],
+                    backend=backend,
+                    **self.backend_kwargs[backend],
+                )
+
                 # test "larger" sparsity
-                with self.assertRaises(AssertionError):
+                with self.assertWarns(UserWarning):
                     sw.plot_unit_templates(
                         self.sorting_analyzer_sparse,
                         sparsity=self.sparsity_large,
diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py
index eb9a90d1d1..258ca2adaa 100644
--- a/src/spikeinterface/widgets/unit_templates.py
+++ b/src/spikeinterface/widgets/unit_templates.py
@@ -24,8 +24,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
         assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview"
 
         # ensure serializable for sortingview
-        unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids
-        unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices
+        unit_id_to_channel_ids = dp.final_sparsity.unit_id_to_channel_ids
+        unit_id_to_channel_indices = dp.final_sparsity.unit_id_to_channel_indices
 
         unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids)
 
diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py
index 59f91306ea..c593836061 100644
--- a/src/spikeinterface/widgets/unit_waveforms.py
+++ b/src/spikeinterface/widgets/unit_waveforms.py
@@ -119,38 +119,50 @@ def __init__(
 
         if unit_ids is None:
            unit_ids = sorting_analyzer_or_templates.unit_ids
-        if channel_ids is None:
-            channel_ids = sorting_analyzer_or_templates.channel_ids
         if unit_colors is None:
             unit_colors = get_unit_colors(sorting_analyzer_or_templates)
 
-        channel_indices = [list(sorting_analyzer_or_templates.channel_ids).index(ch) for ch in channel_ids]
-        channel_locations = sorting_analyzer_or_templates.get_channel_locations()[channel_indices]
-        extra_sparsity = False
-        if sorting_analyzer_or_templates.sparsity is not None:
-            if sparsity is None:
-                sparsity = sorting_analyzer_or_templates.sparsity
-            else:
-                # assert provided sparsity is a subset of waveform sparsity
-                combined_mask = np.logical_or(sorting_analyzer_or_templates.sparsity.mask, sparsity.mask)
-                assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0), (
-                    "The provided 'sparsity' needs to include only the sparse channels "
-                    "used to extract waveforms (for example, by using a smaller 'radius_um')."
-                )
-                extra_sparsity = True
-        else:
-            if sparsity is None:
-                # in this case, we construct a dense sparsity
-                unit_id_to_channel_ids = {
-                    u: sorting_analyzer_or_templates.channel_ids for u in sorting_analyzer_or_templates.unit_ids
-                }
-                sparsity = ChannelSparsity.from_unit_id_to_channel_ids(
-                    unit_id_to_channel_ids=unit_id_to_channel_ids,
-                    unit_ids=sorting_analyzer_or_templates.unit_ids,
-                    channel_ids=sorting_analyzer_or_templates.channel_ids,
-                )
-            else:
-                assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!"
+        channel_locations = sorting_analyzer_or_templates.get_channel_locations()
+        extra_sparsity = None
+        # handle sparsity
+        sparsity_mismatch_warning = (
+            "The provided 'sparsity' includes additional channels not in the analyzer sparsity. "
+            "These extra channels will be plotted as flat lines."
+        )
+        analyzer_sparsity = sorting_analyzer_or_templates.sparsity
+        if channel_ids is not None:
+            assert sparsity is None, "If 'channel_ids' is provided, 'sparsity' should be None!"
+            channel_mask = np.tile(
+                np.isin(sorting_analyzer_or_templates.channel_ids, channel_ids),
+                (len(sorting_analyzer_or_templates.unit_ids), 1),
+            )
+            extra_sparsity = ChannelSparsity(
+                mask=channel_mask,
+                channel_ids=sorting_analyzer_or_templates.channel_ids,
+                unit_ids=sorting_analyzer_or_templates.unit_ids,
+            )
+        elif sparsity is not None:
+            extra_sparsity = sparsity
+
+        if channel_ids is None:
+            channel_ids = sorting_analyzer_or_templates.channel_ids
+
+        # assert provided sparsity is a subset of waveform sparsity
+        if extra_sparsity is not None and analyzer_sparsity is not None:
+            combined_mask = np.logical_or(analyzer_sparsity.mask, extra_sparsity.mask)
+            if not np.all(np.sum(combined_mask, 1) - np.sum(analyzer_sparsity.mask, 1) == 0):
+                warn(sparsity_mismatch_warning)
+
+        final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity
+        if final_sparsity is None:
+            final_sparsity = ChannelSparsity(
+                mask=np.ones(
+                    (len(sorting_analyzer_or_templates.unit_ids), len(sorting_analyzer_or_templates.channel_ids)),
+                    dtype=bool,
+                ),
+                unit_ids=sorting_analyzer_or_templates.unit_ids,
+                channel_ids=sorting_analyzer_or_templates.channel_ids,
+            )
 
         # get templates
         if isinstance(sorting_analyzer_or_templates, Templates):
@@ -174,34 +186,14 @@
                 templates_percentile_shading = None
         templates_shading = self._get_template_shadings(unit_ids, templates_percentile_shading)
 
-        wfs_by_ids = {}
         if plot_waveforms:
             # this must be a sorting_analyzer
             wf_ext = sorting_analyzer_or_templates.get_extension("waveforms")
             if wf_ext is None:
                 raise ValueError("plot_waveforms() needs the extension 'waveforms'")
-            for unit_id in unit_ids:
-                unit_index = list(sorting_analyzer_or_templates.unit_ids).index(unit_id)
-                if not extra_sparsity:
-                    if sorting_analyzer_or_templates.is_sparse():
-                        # wfs = we.get_waveforms(unit_id)
-                        wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
-                    else:
-                        # wfs = we.get_waveforms(unit_id, sparsity=sparsity)
-                        wfs = wf_ext.get_waveforms_one_unit(unit_id)
-                        wfs = wfs[:, :, sparsity.mask[unit_index]]
-                else:
-                    # in this case we have to slice the waveform sparsity based on the extra sparsity
-                    # first get the sparse waveforms
-                    # wfs = we.get_waveforms(unit_id)
-                    wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
-                    # find additional slice to apply to sparse waveforms
-                    (wfs_sparse_indices,) = np.nonzero(sorting_analyzer_or_templates.sparsity.mask[unit_index])
-                    (extra_sparse_indices,) = np.nonzero(sparsity.mask[unit_index])
-                    (extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices))
-                    # apply extra sparsity
-                    wfs = wfs[:, :, extra_slice]
-                wfs_by_ids[unit_id] = wfs
+            wfs_by_ids = self._get_wfs_by_ids(sorting_analyzer_or_templates, unit_ids, extra_sparsity=extra_sparsity)
+        else:
+            wfs_by_ids = None
 
         plot_data = dict(
             sorting_analyzer_or_templates=sorting_analyzer_or_templates,
@@ -209,7 +201,8 @@
             nbefore=nbefore,
             unit_ids=unit_ids,
             channel_ids=channel_ids,
-            sparsity=sparsity,
+            final_sparsity=final_sparsity,
+            extra_sparsity=extra_sparsity,
             unit_colors=unit_colors,
             channel_locations=channel_locations,
             scale=scale,
@@ -270,7 +263,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                 ax = self.axes.flatten()[i]
             color = dp.unit_colors[unit_id]
 
-            chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id]
+            chan_inds = dp.final_sparsity.unit_id_to_channel_indices[unit_id]
             xvectors_flat = xvectors[:, chan_inds].T.flatten()
 
             # plot waveforms
@@ -502,6 +495,32 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
         if backend_kwargs["display"]:
             display(self.widget)
 
+    def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, extra_sparsity):
+        wfs_by_ids = {}
+        wf_ext = sorting_analyzer.get_extension("waveforms")
+        for unit_id in unit_ids:
+            unit_index = list(sorting_analyzer.unit_ids).index(unit_id)
+            if extra_sparsity is None:
+                wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
+            else:
+                # in this case we have to construct waveforms based on the extra sparsity and add the
+                # sparse waveforms on the valid channels
+                if sorting_analyzer.is_sparse():
+                    original_mask = sorting_analyzer.sparsity.mask[unit_index]
+                else:
+                    original_mask = np.ones(len(sorting_analyzer.channel_ids), dtype=bool)
+                wfs_orig = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
+                wfs = np.zeros(
+                    (wfs_orig.shape[0], wfs_orig.shape[1], extra_sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype
+                )
+                # fill in the existing waveforms channels
+                valid_wfs_indices = extra_sparsity.mask[unit_index][original_mask]
+                valid_extra_indices = original_mask[extra_sparsity.mask[unit_index]]
+                wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices]
+
+            wfs_by_ids[unit_id] = wfs
+        return wfs_by_ids
+
     def _get_template_shadings(self, unit_ids, templates_percentile_shading):
         templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average")
 
@@ -538,6 +557,8 @@ def _update_plot(self, change):
         hide_axis = self.hide_axis_button.value
         do_shading = self.template_shading_button.value
 
+        data_plot = self.next_data_plot
+
         if self.sorting_analyzer is not None:
             templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average")
             templates_shadings = self._get_template_shadings(unit_ids, data_plot["templates_percentile_shading"])
@@ -549,7 +570,6 @@
             channel_locations = self.templates.get_channel_locations()
 
         # matplotlib next_data_plot dict update at each call
-        data_plot = self.next_data_plot
         data_plot["unit_ids"] = unit_ids
         data_plot["templates"] = templates
         data_plot["templates_shading"] = templates_shadings
@@ -564,10 +584,10 @@
         data_plot["scalebar"] = self.scalebar.value
 
         if data_plot["plot_waveforms"]:
-            wf_ext = self.sorting_analyzer.get_extension("waveforms")
-            data_plot["wfs_by_ids"] = {
-                unit_id: wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) for unit_id in unit_ids
-            }
+            wfs_by_ids = self._get_wfs_by_ids(
+                self.sorting_analyzer, unit_ids, extra_sparsity=data_plot["extra_sparsity"]
+            )
+            data_plot["wfs_by_ids"] = wfs_by_ids
 
         # TODO option for plot_legend
         backend_kwargs = {}
@@ -611,7 +631,7 @@
 
         # TODO this could be done with probeinterface plotting plotting tools!!
         for unit in unit_ids:
-            channel_inds = self.data_plot["sparsity"].unit_id_to_channel_indices[unit]
+            channel_inds = self.data_plot["final_sparsity"].unit_id_to_channel_indices[unit]
             ax.plot(
                 channel_locations[channel_inds, 0],
                 channel_locations[channel_inds, 1],
diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py
index ac0676e4c7..ca09cc4d8f 100644
--- a/src/spikeinterface/widgets/utils.py
+++ b/src/spikeinterface/widgets/utils.py
@@ -151,6 +151,7 @@ def array_to_image(
     output_image : 3D numpy array
 
     """
+    import matplotlib.pyplot as plt
    from scipy.ndimage import zoom
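
The new _get_wfs_by_ids helper above fills a zero-initialized array shaped by the requested (extra) sparsity with the waveforms stored under the analyzer's own sparsity, which is what makes channels outside the stored set appear as flat lines. The following standalone NumPy sketch illustrates that boolean cross-indexing with made-up masks and shapes; it is an illustration of the indexing trick, not part of the patch.

import numpy as np

# channels with stored (analyzer-sparse) waveforms vs. channels requested for plotting
original_mask = np.array([0, 1, 1, 1, 0, 0, 0, 0], dtype=bool)  # analyzer sparsity for one unit
extra_mask = np.array([0, 0, 1, 1, 1, 1, 0, 0], dtype=bool)     # requested (extra) sparsity for the same unit

rng = np.random.default_rng(0)
wfs_orig = rng.standard_normal((100, 60, original_mask.sum()))  # (num_spikes, num_samples, num_sparse_channels)
wfs = np.zeros((wfs_orig.shape[0], wfs_orig.shape[1], extra_mask.sum()), dtype=wfs_orig.dtype)

# which stored channels are also requested, and which requested channels actually have stored data
valid_wfs_indices = extra_mask[original_mask]
valid_extra_indices = original_mask[extra_mask]
wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices]

# channels requested but never extracted (here channels 4 and 5) stay at zero -> flat lines in the plot
assert np.all(wfs[:, :, ~valid_extra_indices] == 0)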
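For reference, a minimal usage sketch of what the updated tests exercise: the new channel_ids argument (with sparsity left as None) and the switch from an AssertionError to a UserWarning when a sparsity wider than the analyzer's is passed. The setup helpers (generate_ground_truth_recording, create_sorting_analyzer, compute_sparsity) come from the general spikeinterface API and are assumed here only to make the sketch self-contained; the parameter values are illustrative.

import spikeinterface.core as si
import spikeinterface.widgets as sw

# toy data and a sparse SortingAnalyzer (setup only, not part of this patch)
recording, sorting = si.generate_ground_truth_recording(durations=[30.0], num_channels=32, num_units=8, seed=0)
analyzer = si.create_sorting_analyzer(sorting, recording, sparse=True)
analyzer.compute(["random_spikes", "waveforms", "templates"])

# new: restrict plotting to a subset of channels via channel_ids (sparsity must stay None)
sw.plot_unit_templates(
    analyzer,
    channel_ids=analyzer.channel_ids[::3],
    unit_ids=analyzer.unit_ids[:4],
    backend="matplotlib",
)

# changed: a sparsity wider than the analyzer's own now emits a UserWarning instead of raising;
# channels without stored waveforms are drawn as flat lines
large_sparsity = si.compute_sparsity(analyzer, method="radius", radius_um=200.0)
sw.plot_unit_waveforms(analyzer, sparsity=large_sparsity, unit_ids=analyzer.unit_ids[:4], backend="matplotlib")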