From bf113ab4a235d635716cd9fd099cefb8315b0694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 15 May 2024 13:34:22 +0200 Subject: [PATCH 1/6] Fix bug in plot templates Fixes a bug where specifying the `channel_ids` with a non-sparse analyzer would fail. --- src/spikeinterface/widgets/unit_waveforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index f701f9a868..9e46c8f5fc 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -134,11 +134,11 @@ def __init__( else: if sparsity is None: # in this case, we construct a dense sparsity - unit_id_to_channel_ids = {u: sorting_analyzer.channel_ids for u in sorting_analyzer.unit_ids} + unit_id_to_channel_ids = {u: channel_ids for u in sorting_analyzer.unit_ids} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=sorting_analyzer.unit_ids, - channel_ids=sorting_analyzer.channel_ids, + channel_ids=channel_ids, ) else: assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" From 1360d8a41f9723463ef94cd0c61ecfae537f2341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 15 May 2024 13:49:40 +0200 Subject: [PATCH 2/6] Fixed bug --- src/spikeinterface/widgets/unit_waveforms.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 9e46c8f5fc..6fcdfad5bf 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -115,9 +115,8 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(sorting) - channel_locations = sorting_analyzer.get_channel_locations()[ - sorting_analyzer.channel_ids_to_indices(channel_ids) - ] + channel_indices = sorting_analyzer.channel_ids_to_indices(channel_ids) + channel_locations = sorting_analyzer.get_channel_locations()[channel_indices] extra_sparsity = False if sorting_analyzer.is_sparse(): @@ -146,7 +145,7 @@ def __init__( # get templates self.templates_ext = sorting_analyzer.get_extension("templates") assert self.templates_ext is not None, "plot_waveforms() need extension 'templates'" - templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") + templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average")[:, :, channel_indices] if templates_percentile_shading is not None and not sorting_analyzer.has_extension("waveforms"): warn( From 22887031f8c7ed5fd3778e46105d2f224d7dd3f1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 8 Jul 2024 12:52:41 +0200 Subject: [PATCH 3/6] Fix tests --- src/spikeinterface/widgets/tests/test_widgets.py | 6 +++--- src/spikeinterface/widgets/unit_waveforms.py | 7 ++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 012b1ac07c..91da7da5f5 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -190,8 +190,8 @@ def test_plot_unit_waveforms(self): backend=backend, **self.backend_kwargs[backend], ) - # test "larger" sparsity - with self.assertRaises(AssertionError): + # test warning with "larger" sparsity + with self.assertWarns(UserWarning): sw.plot_unit_waveforms( self.sorting_analyzer_sparse, sparsity=self.sparsity_large, @@ -270,7 +270,7 @@ def test_plot_unit_templates(self): **self.backend_kwargs[backend], ) # test "larger" sparsity - with self.assertRaises(AssertionError): + with self.assertWarns(UserWarning): sw.plot_unit_templates( self.sorting_analyzer_sparse, sparsity=self.sparsity_large, diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 24def45a85..c60d788283 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -159,6 +159,9 @@ def __init__( else: assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" + if channel_ids is None: + channel_ids = sorting_analyzer_or_templates.channel_ids + # assert provided sparsity is a subset of waveform sparsity if extra_sparsity: combined_mask = np.logical_or(analyzer_sparsity.mask, sparsity.mask) @@ -195,12 +198,15 @@ def __init__( wfs_by_ids = self._get_wfs_by_ids( sorting_analyzer_or_templates, unit_ids, sparsity, extra_sparsity=extra_sparsity ) + else: + wfs_by_ids = None plot_data = dict( sorting_analyzer_or_templates=sorting_analyzer_or_templates, sampling_frequency=sorting_analyzer_or_templates.sampling_frequency, nbefore=nbefore, unit_ids=unit_ids, + channel_ids=channel_ids, sparsity=sparsity, unit_colors=unit_colors, channel_locations=channel_locations, @@ -269,7 +275,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # plot waveforms if dp.plot_waveforms: wfs = dp.wfs_by_ids[unit_id] * dp.scale - print(wfs.shape) if dp.unit_selected_waveforms is not None: wfs = wfs[dp.unit_selected_waveforms[unit_id]] elif dp.max_spikes_per_unit is not None: From c7c733c8918342d7ef5629064543c8f68d5631e7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 12 Jul 2024 13:08:12 +0200 Subject: [PATCH 4/6] Add extra-slicing of waveforms --- src/spikeinterface/widgets/unit_waveforms.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index c60d788283..173f8fbc16 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -514,11 +514,17 @@ def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, sparsity, extra_sparsity=F wfs = wf_ext.get_waveforms_one_unit(unit_id) wfs = wfs[:, :, sparsity.mask[unit_index]] else: - # in this case we have to slice the dense waveforms based on the extra sparsity - # first get the sparse waveforms - wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=True) - # apply extra sparsity - wfs = wfs[:, :, sparsity.mask[unit_index]] + # in this case we have to construct waveforms based on the extra sparsity and add the + # sparse waveforms on the valid channels + wfs_orig = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) + wfs = np.zeros( + (wfs_orig.shape[0], wfs_orig.shape[1], sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype + ) + # fill in the existing waveforms channels + valid_wfs_indices = sparsity.mask[unit_index][sorting_analyzer.sparsity.mask[unit_index]] + valid_extra_indices = sorting_analyzer.sparsity.mask[unit_index][sparsity.mask[unit_index]] + wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices] + wfs_by_ids[unit_id] = wfs return wfs_by_ids From 35a6055f752423eee3e0da9e39ffe1f6f380e20a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 12 Jul 2024 13:11:15 +0200 Subject: [PATCH 5/6] add tests --- .../widgets/tests/test_widgets.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 91da7da5f5..a09304dc86 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -190,6 +190,14 @@ def test_plot_unit_waveforms(self): backend=backend, **self.backend_kwargs[backend], ) + # channel ids + sw.plot_unit_waveforms( + self.sorting_analyzer_sparse, + channel_ids=self.sorting_analyzer_sparse.channel_ids[::3], + unit_ids=unit_ids, + backend=backend, + **self.backend_kwargs[backend], + ) # test warning with "larger" sparsity with self.assertWarns(UserWarning): sw.plot_unit_waveforms( @@ -205,10 +213,10 @@ def test_plot_unit_templates(self): for backend in possible_backends: if backend not in self.skip_backends: print(f"Testing backend {backend}") - print("Dense") + # dense sw.plot_unit_templates(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] - print("Dense + radius") + # dense + radius sw.plot_unit_templates( self.sorting_analyzer_dense, sparsity=self.sparsity_radius, @@ -216,7 +224,7 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) - print("Dense + best") + # dense + best sw.plot_unit_templates( self.sorting_analyzer_dense, sparsity=self.sparsity_best, @@ -225,7 +233,6 @@ def test_plot_unit_templates(self): **self.backend_kwargs[backend], ) # test different shadings - print("Sparse") sw.plot_unit_templates( self.sorting_analyzer_sparse, unit_ids=unit_ids, @@ -233,7 +240,6 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) - print("Sparse2") sw.plot_unit_templates( self.sorting_analyzer_sparse, unit_ids=unit_ids, @@ -242,8 +248,6 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) - # test different shadings - print("Sparse3") sw.plot_unit_templates( self.sorting_analyzer_sparse, unit_ids=unit_ids, @@ -252,7 +256,6 @@ def test_plot_unit_templates(self): shade_templates=False, **self.backend_kwargs[backend], ) - print("Sparse4") sw.plot_unit_templates( self.sorting_analyzer_sparse, unit_ids=unit_ids, @@ -260,7 +263,7 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) - print("Extra sparsity") + # extra sparsity sw.plot_unit_templates( self.sorting_analyzer_sparse, sparsity=self.sparsity_strict, @@ -269,6 +272,16 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) + # channel ids + sw.plot_unit_templates( + self.sorting_analyzer_sparse, + channel_ids=self.sorting_analyzer_sparse.channel_ids[::3], + unit_ids=unit_ids, + templates_percentile_shading=[1, 10, 90, 99], + backend=backend, + **self.backend_kwargs[backend], + ) + # test "larger" sparsity with self.assertWarns(UserWarning): sw.plot_unit_templates( From 600f25ece07aace3544d66f5511d8b26743fb908 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 15 Jul 2024 13:07:14 +0200 Subject: [PATCH 6/6] Simplify sparsity handling in plot waveforms/templates and fix plot_traces sortingview --- src/spikeinterface/widgets/unit_templates.py | 4 +- src/spikeinterface/widgets/unit_waveforms.py | 77 +++++++++----------- src/spikeinterface/widgets/utils.py | 1 + 3 files changed, 38 insertions(+), 44 deletions(-) diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index eb9a90d1d1..258ca2adaa 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -24,8 +24,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview" # ensure serializable for sortingview - unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids - unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices + unit_id_to_channel_ids = dp.final_sparsity.unit_id_to_channel_ids + unit_id_to_channel_indices = dp.final_sparsity.unit_id_to_channel_indices unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 173f8fbc16..c593836061 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -123,7 +123,7 @@ def __init__( unit_colors = get_unit_colors(sorting_analyzer_or_templates) channel_locations = sorting_analyzer_or_templates.get_channel_locations() - extra_sparsity = False + extra_sparsity = None # handle sparsity sparsity_mismatch_warning = ( "The provided 'sparsity' includes additional channels not in the analyzer sparsity. " @@ -131,43 +131,39 @@ def __init__( ) analyzer_sparsity = sorting_analyzer_or_templates.sparsity if channel_ids is not None: + assert sparsity is None, "If 'channel_ids' is provided, 'sparsity' should be None!" channel_mask = np.tile( np.isin(sorting_analyzer_or_templates.channel_ids, channel_ids), (len(sorting_analyzer_or_templates.unit_ids), 1), ) - sparsity = ChannelSparsity( + extra_sparsity = ChannelSparsity( mask=channel_mask, channel_ids=sorting_analyzer_or_templates.channel_ids, unit_ids=sorting_analyzer_or_templates.unit_ids, ) - extra_sparsity = True - elif analyzer_sparsity is not None: - if sparsity is None: - sparsity = analyzer_sparsity - else: - extra_sparsity = True - else: - if sparsity is None: - unit_id_to_channel_ids = { - u: sorting_analyzer_or_templates.channel_ids for u in sorting_analyzer_or_templates.unit_ids - } - sparsity = ChannelSparsity.from_unit_id_to_channel_ids( - unit_id_to_channel_ids=unit_id_to_channel_ids, - unit_ids=sorting_analyzer_or_templates.unit_ids, - channel_ids=sorting_analyzer_or_templates.channel_ids, - ) - else: - assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" + elif sparsity is not None: + extra_sparsity = sparsity if channel_ids is None: channel_ids = sorting_analyzer_or_templates.channel_ids # assert provided sparsity is a subset of waveform sparsity - if extra_sparsity: - combined_mask = np.logical_or(analyzer_sparsity.mask, sparsity.mask) - if not np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0): + if extra_sparsity is not None and analyzer_sparsity is not None: + combined_mask = np.logical_or(analyzer_sparsity.mask, extra_sparsity.mask) + if not np.all(np.sum(combined_mask, 1) - np.sum(analyzer_sparsity.mask, 1) == 0): warn(sparsity_mismatch_warning) + final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity + if final_sparsity is None: + final_sparsity = ChannelSparsity( + mask=np.ones( + (len(sorting_analyzer_or_templates.unit_ids), len(sorting_analyzer_or_templates.channel_ids)), + dtype=bool, + ), + unit_ids=sorting_analyzer_or_templates.unit_ids, + channel_ids=sorting_analyzer_or_templates.channel_ids, + ) + # get templates if isinstance(sorting_analyzer_or_templates, Templates): templates = sorting_analyzer_or_templates.templates_array @@ -195,9 +191,7 @@ def __init__( wf_ext = sorting_analyzer_or_templates.get_extension("waveforms") if wf_ext is None: raise ValueError("plot_waveforms() needs the extension 'waveforms'") - wfs_by_ids = self._get_wfs_by_ids( - sorting_analyzer_or_templates, unit_ids, sparsity, extra_sparsity=extra_sparsity - ) + wfs_by_ids = self._get_wfs_by_ids(sorting_analyzer_or_templates, unit_ids, extra_sparsity=extra_sparsity) else: wfs_by_ids = None @@ -207,7 +201,8 @@ def __init__( nbefore=nbefore, unit_ids=unit_ids, channel_ids=channel_ids, - sparsity=sparsity, + final_sparsity=final_sparsity, + extra_sparsity=extra_sparsity, unit_colors=unit_colors, channel_locations=channel_locations, scale=scale, @@ -234,7 +229,6 @@ def __init__( alpha_templates=alpha_templates, hide_unit_selector=hide_unit_selector, plot_legend=plot_legend, - extra_sparsity=extra_sparsity, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -269,7 +263,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax = self.axes.flatten()[i] color = dp.unit_colors[unit_id] - chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id] + chan_inds = dp.final_sparsity.unit_id_to_channel_indices[unit_id] xvectors_flat = xvectors[:, chan_inds].T.flatten() # plot waveforms @@ -501,28 +495,27 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: display(self.widget) - def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, sparsity, extra_sparsity=False): + def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, extra_sparsity): wfs_by_ids = {} wf_ext = sorting_analyzer.get_extension("waveforms") for unit_id in unit_ids: unit_index = list(sorting_analyzer.unit_ids).index(unit_id) - if not extra_sparsity: - # get waveforms with default sparsity - if sorting_analyzer.is_sparse(): - wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) - else: - wfs = wf_ext.get_waveforms_one_unit(unit_id) - wfs = wfs[:, :, sparsity.mask[unit_index]] + if extra_sparsity is None: + wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) else: # in this case we have to construct waveforms based on the extra sparsity and add the # sparse waveforms on the valid channels + if sorting_analyzer.is_sparse(): + original_mask = sorting_analyzer.sparsity.mask[unit_index] + else: + original_mask = np.ones(len(sorting_analyzer.channel_ids), dtype=bool) wfs_orig = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) wfs = np.zeros( - (wfs_orig.shape[0], wfs_orig.shape[1], sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype + (wfs_orig.shape[0], wfs_orig.shape[1], extra_sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype ) # fill in the existing waveforms channels - valid_wfs_indices = sparsity.mask[unit_index][sorting_analyzer.sparsity.mask[unit_index]] - valid_extra_indices = sorting_analyzer.sparsity.mask[unit_index][sparsity.mask[unit_index]] + valid_wfs_indices = extra_sparsity.mask[unit_index][original_mask] + valid_extra_indices = original_mask[extra_sparsity.mask[unit_index]] wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices] wfs_by_ids[unit_id] = wfs @@ -592,7 +585,7 @@ def _update_plot(self, change): if data_plot["plot_waveforms"]: wfs_by_ids = self._get_wfs_by_ids( - self.sorting_analyzer, unit_ids, data_plot["sparsity"], extra_sparsity=data_plot["extra_sparsity"] + self.sorting_analyzer, unit_ids, extra_sparsity=data_plot["extra_sparsity"] ) data_plot["wfs_by_ids"] = wfs_by_ids @@ -638,7 +631,7 @@ def _plot_probe(self, ax, channel_locations, unit_ids): # TODO this could be done with probeinterface plotting plotting tools!! for unit in unit_ids: - channel_inds = self.data_plot["sparsity"].unit_id_to_channel_indices[unit] + channel_inds = self.data_plot["final_sparsity"].unit_id_to_channel_indices[unit] ax.plot( channel_locations[channel_inds, 0], channel_locations[channel_inds, 1], diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index ac0676e4c7..ca09cc4d8f 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -151,6 +151,7 @@ def array_to_image( output_image : 3D numpy array """ + import matplotlib.pyplot as plt from scipy.ndimage import zoom