Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in plot templates #2850

Merged
merged 9 commits into from
Jul 15, 2024
37 changes: 25 additions & 12 deletions src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,16 @@ def test_plot_unit_waveforms(self):
backend=backend,
**self.backend_kwargs[backend],
)
# test "larger" sparsity
with self.assertRaises(AssertionError):
# channel ids
sw.plot_unit_waveforms(
self.sorting_analyzer_sparse,
channel_ids=self.sorting_analyzer_sparse.channel_ids[::3],
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
# test warning with "larger" sparsity
with self.assertWarns(UserWarning):
sw.plot_unit_waveforms(
self.sorting_analyzer_sparse,
sparsity=self.sparsity_large,
Expand All @@ -205,18 +213,18 @@ def test_plot_unit_templates(self):
for backend in possible_backends:
if backend not in self.skip_backends:
print(f"Testing backend {backend}")
print("Dense")
# dense
sw.plot_unit_templates(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend])
unit_ids = self.sorting.unit_ids[:6]
print("Dense + radius")
# dense + radius
sw.plot_unit_templates(
self.sorting_analyzer_dense,
sparsity=self.sparsity_radius,
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
print("Dense + best")
# dense + best
sw.plot_unit_templates(
self.sorting_analyzer_dense,
sparsity=self.sparsity_best,
Expand All @@ -225,15 +233,13 @@ def test_plot_unit_templates(self):
**self.backend_kwargs[backend],
)
# test different shadings
print("Sparse")
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
unit_ids=unit_ids,
templates_percentile_shading=None,
backend=backend,
**self.backend_kwargs[backend],
)
print("Sparse2")
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
unit_ids=unit_ids,
Expand All @@ -242,8 +248,6 @@ def test_plot_unit_templates(self):
backend=backend,
**self.backend_kwargs[backend],
)
# test different shadings
print("Sparse3")
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
unit_ids=unit_ids,
Expand All @@ -252,15 +256,14 @@ def test_plot_unit_templates(self):
shade_templates=False,
**self.backend_kwargs[backend],
)
print("Sparse4")
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
unit_ids=unit_ids,
templates_percentile_shading=0.1,
backend=backend,
**self.backend_kwargs[backend],
)
print("Extra sparsity")
# extra sparsity
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
sparsity=self.sparsity_strict,
Expand All @@ -269,8 +272,18 @@ def test_plot_unit_templates(self):
backend=backend,
**self.backend_kwargs[backend],
)
# channel ids
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
channel_ids=self.sorting_analyzer_sparse.channel_ids[::3],
unit_ids=unit_ids,
templates_percentile_shading=[1, 10, 90, 99],
backend=backend,
**self.backend_kwargs[backend],
)

# test "larger" sparsity
with self.assertRaises(AssertionError):
with self.assertWarns(UserWarning):
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
sparsity=self.sparsity_large,
Expand Down
109 changes: 68 additions & 41 deletions src/spikeinterface/widgets/unit_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,28 +119,35 @@ def __init__(

if unit_ids is None:
unit_ids = sorting_analyzer_or_templates.unit_ids
if channel_ids is None:
channel_ids = sorting_analyzer_or_templates.channel_ids
if unit_colors is None:
unit_colors = get_unit_colors(sorting_analyzer_or_templates)

channel_indices = [list(sorting_analyzer_or_templates.channel_ids).index(ch) for ch in channel_ids]
channel_locations = sorting_analyzer_or_templates.get_channel_locations()[channel_indices]
channel_locations = sorting_analyzer_or_templates.get_channel_locations()
extra_sparsity = False
if sorting_analyzer_or_templates.sparsity is not None:
# handle sparsity
sparsity_mismatch_warning = (
"The provided 'sparsity' includes additional channels not in the analyzer sparsity. "
"These extra channels will be plotted as flat lines."
)
analyzer_sparsity = sorting_analyzer_or_templates.sparsity
if channel_ids is not None:
channel_mask = np.tile(
np.isin(sorting_analyzer_or_templates.channel_ids, channel_ids),
(len(sorting_analyzer_or_templates.unit_ids), 1),
)
sparsity = ChannelSparsity(
mask=channel_mask,
channel_ids=sorting_analyzer_or_templates.channel_ids,
unit_ids=sorting_analyzer_or_templates.unit_ids,
)
extra_sparsity = True
elif analyzer_sparsity is not None:
if sparsity is None:
sparsity = sorting_analyzer_or_templates.sparsity
sparsity = analyzer_sparsity
else:
# assert provided sparsity is a subset of waveform sparsity
combined_mask = np.logical_or(sorting_analyzer_or_templates.sparsity.mask, sparsity.mask)
assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0), (
"The provided 'sparsity' needs to include only the sparse channels "
"used to extract waveforms (for example, by using a smaller 'radius_um')."
)
extra_sparsity = True
else:
if sparsity is None:
# in this case, we construct a dense sparsity
unit_id_to_channel_ids = {
u: sorting_analyzer_or_templates.channel_ids for u in sorting_analyzer_or_templates.unit_ids
}
Expand All @@ -152,6 +159,15 @@ def __init__(
else:
assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!"

if channel_ids is None:
channel_ids = sorting_analyzer_or_templates.channel_ids

# assert provided sparsity is a subset of waveform sparsity
if extra_sparsity:
combined_mask = np.logical_or(analyzer_sparsity.mask, sparsity.mask)
if not np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0):
warn(sparsity_mismatch_warning)

# get templates
if isinstance(sorting_analyzer_or_templates, Templates):
templates = sorting_analyzer_or_templates.templates_array
Expand All @@ -174,34 +190,16 @@ def __init__(
templates_percentile_shading = None
templates_shading = self._get_template_shadings(unit_ids, templates_percentile_shading)

wfs_by_ids = {}
if plot_waveforms:
# this must be a sorting_analyzer
wf_ext = sorting_analyzer_or_templates.get_extension("waveforms")
if wf_ext is None:
raise ValueError("plot_waveforms() needs the extension 'waveforms'")
for unit_id in unit_ids:
unit_index = list(sorting_analyzer_or_templates.unit_ids).index(unit_id)
if not extra_sparsity:
if sorting_analyzer_or_templates.is_sparse():
# wfs = we.get_waveforms(unit_id)
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
else:
# wfs = we.get_waveforms(unit_id, sparsity=sparsity)
wfs = wf_ext.get_waveforms_one_unit(unit_id)
wfs = wfs[:, :, sparsity.mask[unit_index]]
else:
# in this case we have to slice the waveform sparsity based on the extra sparsity
# first get the sparse waveforms
# wfs = we.get_waveforms(unit_id)
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
# find additional slice to apply to sparse waveforms
(wfs_sparse_indices,) = np.nonzero(sorting_analyzer_or_templates.sparsity.mask[unit_index])
(extra_sparse_indices,) = np.nonzero(sparsity.mask[unit_index])
(extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices))
# apply extra sparsity
wfs = wfs[:, :, extra_slice]
wfs_by_ids[unit_id] = wfs
wfs_by_ids = self._get_wfs_by_ids(
sorting_analyzer_or_templates, unit_ids, sparsity, extra_sparsity=extra_sparsity
)
else:
wfs_by_ids = None

plot_data = dict(
sorting_analyzer_or_templates=sorting_analyzer_or_templates,
Expand Down Expand Up @@ -236,6 +234,7 @@ def __init__(
alpha_templates=alpha_templates,
hide_unit_selector=hide_unit_selector,
plot_legend=plot_legend,
extra_sparsity=extra_sparsity,
)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

Expand Down Expand Up @@ -502,6 +501,33 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
if backend_kwargs["display"]:
display(self.widget)

def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, sparsity, extra_sparsity=False):
wfs_by_ids = {}
wf_ext = sorting_analyzer.get_extension("waveforms")
for unit_id in unit_ids:
unit_index = list(sorting_analyzer.unit_ids).index(unit_id)
if not extra_sparsity:
# get waveforms with default sparsity
if sorting_analyzer.is_sparse():
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
else:
wfs = wf_ext.get_waveforms_one_unit(unit_id)
wfs = wfs[:, :, sparsity.mask[unit_index]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this if not extra_sparsity ? does the sparsity represent the extra sparsity ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra sparsity is a bool, in that case the sparsity is the user provided sparsity (or from channel ids)

else:
# in this case we have to construct waveforms based on the extra sparsity and add the
# sparse waveforms on the valid channels
wfs_orig = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
wfs = np.zeros(
(wfs_orig.shape[0], wfs_orig.shape[1], sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype
)
# fill in the existing waveforms channels
valid_wfs_indices = sparsity.mask[unit_index][sorting_analyzer.sparsity.mask[unit_index]]
valid_extra_indices = sorting_analyzer.sparsity.mask[unit_index][sparsity.mask[unit_index]]
wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices]

wfs_by_ids[unit_id] = wfs
return wfs_by_ids

def _get_template_shadings(self, unit_ids, templates_percentile_shading):
templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average")

Expand Down Expand Up @@ -538,6 +564,8 @@ def _update_plot(self, change):
hide_axis = self.hide_axis_button.value
do_shading = self.template_shading_button.value

data_plot = self.next_data_plot

if self.sorting_analyzer is not None:
templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average")
templates_shadings = self._get_template_shadings(unit_ids, data_plot["templates_percentile_shading"])
Expand All @@ -549,7 +577,6 @@ def _update_plot(self, change):
channel_locations = self.templates.get_channel_locations()

# matplotlib next_data_plot dict update at each call
data_plot = self.next_data_plot
data_plot["unit_ids"] = unit_ids
data_plot["templates"] = templates
data_plot["templates_shading"] = templates_shadings
Expand All @@ -564,10 +591,10 @@ def _update_plot(self, change):
data_plot["scalebar"] = self.scalebar.value

if data_plot["plot_waveforms"]:
wf_ext = self.sorting_analyzer.get_extension("waveforms")
data_plot["wfs_by_ids"] = {
unit_id: wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) for unit_id in unit_ids
}
wfs_by_ids = self._get_wfs_by_ids(
self.sorting_analyzer, unit_ids, data_plot["sparsity"], extra_sparsity=data_plot["extra_sparsity"]
)
data_plot["wfs_by_ids"] = wfs_by_ids

# TODO option for plot_legend
backend_kwargs = {}
Expand Down