Merge pull request #2850 from DradeAW/patch-1
Fix bug in plot templates
samuelgarcia authored Jul 15, 2024
2 parents c9fc8e1 + 600f25e commit ab59a93
Showing 4 changed files with 108 additions and 74 deletions.
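
As the updated tests below show, the plotting widgets now accept a channel_ids argument to restrict the plotted channels. A minimal usage sketch, assuming an existing SortingAnalyzer named analyzer (a placeholder name, not part of this diff):

import spikeinterface.widgets as sw

unit_ids = analyzer.unit_ids[:6]

# Plot waveforms on every third channel only; the widget builds the matching
# ChannelSparsity internally (see the unit_waveforms.py hunk further down).
sw.plot_unit_waveforms(
    analyzer,
    channel_ids=analyzer.channel_ids[::3],
    unit_ids=unit_ids,
    backend="matplotlib",
)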
37 changes: 25 additions & 12 deletions src/spikeinterface/widgets/tests/test_widgets.py
@@ -190,8 +190,16 @@ def test_plot_unit_waveforms(self):
backend=backend,
**self.backend_kwargs[backend],
)
# test "larger" sparsity
with self.assertRaises(AssertionError):
# channel ids
sw.plot_unit_waveforms(
self.sorting_analyzer_sparse,
channel_ids=self.sorting_analyzer_sparse.channel_ids[::3],
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
# test warning with "larger" sparsity
with self.assertWarns(UserWarning):
sw.plot_unit_waveforms(
self.sorting_analyzer_sparse,
sparsity=self.sparsity_large,
@@ -205,18 +213,18 @@ def test_plot_unit_templates(self):
for backend in possible_backends:
if backend not in self.skip_backends:
print(f"Testing backend {backend}")
print("Dense")
# dense
sw.plot_unit_templates(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend])
unit_ids = self.sorting.unit_ids[:6]
print("Dense + radius")
# dense + radius
sw.plot_unit_templates(
self.sorting_analyzer_dense,
sparsity=self.sparsity_radius,
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
print("Dense + best")
# dense + best
sw.plot_unit_templates(
self.sorting_analyzer_dense,
sparsity=self.sparsity_best,
@@ -225,15 +233,13 @@
**self.backend_kwargs[backend],
)
# test different shadings
print("Sparse")
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
unit_ids=unit_ids,
templates_percentile_shading=None,
backend=backend,
**self.backend_kwargs[backend],
)
print("Sparse2")
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
unit_ids=unit_ids,
@@ -242,8 +248,6 @@
backend=backend,
**self.backend_kwargs[backend],
)
# test different shadings
print("Sparse3")
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
unit_ids=unit_ids,
@@ -252,15 +256,14 @@
shade_templates=False,
**self.backend_kwargs[backend],
)
print("Sparse4")
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
unit_ids=unit_ids,
templates_percentile_shading=0.1,
backend=backend,
**self.backend_kwargs[backend],
)
print("Extra sparsity")
# extra sparsity
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
sparsity=self.sparsity_strict,
@@ -269,8 +272,18 @@
backend=backend,
**self.backend_kwargs[backend],
)
# channel ids
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
channel_ids=self.sorting_analyzer_sparse.channel_ids[::3],
unit_ids=unit_ids,
templates_percentile_shading=[1, 10, 90, 99],
backend=backend,
**self.backend_kwargs[backend],
)

# test "larger" sparsity
with self.assertRaises(AssertionError):
with self.assertWarns(UserWarning):
sw.plot_unit_templates(
self.sorting_analyzer_sparse,
sparsity=self.sparsity_large,
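The last hunk above switches the oversized-sparsity case from assertRaises to assertWarns. A minimal sketch of that behavior from the caller's side, assuming analyzer is a sparse SortingAnalyzer (placeholder name); a ChannelSparsity with an all-True mask is by construction wider than the analyzer's own sparsity:

import warnings
import numpy as np
from spikeinterface.core import ChannelSparsity
import spikeinterface.widgets as sw

larger_sparsity = ChannelSparsity(
    mask=np.ones((len(analyzer.unit_ids), len(analyzer.channel_ids)), dtype=bool),
    unit_ids=analyzer.unit_ids,
    channel_ids=analyzer.channel_ids,
)

# Previously this raised an AssertionError; it now emits a UserWarning and the
# channels missing from the analyzer sparsity are drawn as flat lines.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sw.plot_unit_templates(
        analyzer,
        sparsity=larger_sparsity,
        unit_ids=analyzer.unit_ids[:6],
        backend="matplotlib",
    )
assert any(issubclass(w.category, UserWarning) for w in caught)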
4 changes: 2 additions & 2 deletions src/spikeinterface/widgets/unit_templates.py
@@ -24,8 +24,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
assert len(dp.templates_shading) <= 4, "Only 2 and 4 templates shading are supported in sortingview"

# ensure serializable for sortingview
unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids
unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices
unit_id_to_channel_ids = dp.final_sparsity.unit_id_to_channel_ids
unit_id_to_channel_indices = dp.final_sparsity.unit_id_to_channel_indices

unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids)

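The sortingview backend now reads final_sparsity from the plot data instead of the raw sparsity argument. A condensed sketch of how that value is resolved, based on the unit_waveforms.py hunk below (the helper function and its name are mine, for illustration only):

import numpy as np
from spikeinterface.core import ChannelSparsity

def resolve_final_sparsity(analyzer_sparsity, extra_sparsity, unit_ids, channel_ids):
    # a user-provided sparsity (from 'sparsity' or 'channel_ids') takes precedence ...
    final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity
    # ... and a fully dense mask is the fallback when neither exists
    if final_sparsity is None:
        final_sparsity = ChannelSparsity(
            mask=np.ones((len(unit_ids), len(channel_ids)), dtype=bool),
            unit_ids=unit_ids,
            channel_ids=channel_ids,
        )
    return final_sparsity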
140 changes: 80 additions & 60 deletions src/spikeinterface/widgets/unit_waveforms.py
@@ -119,38 +119,50 @@ def __init__(

if unit_ids is None:
unit_ids = sorting_analyzer_or_templates.unit_ids
if channel_ids is None:
channel_ids = sorting_analyzer_or_templates.channel_ids
if unit_colors is None:
unit_colors = get_unit_colors(sorting_analyzer_or_templates)

channel_indices = [list(sorting_analyzer_or_templates.channel_ids).index(ch) for ch in channel_ids]
channel_locations = sorting_analyzer_or_templates.get_channel_locations()[channel_indices]
extra_sparsity = False
if sorting_analyzer_or_templates.sparsity is not None:
if sparsity is None:
sparsity = sorting_analyzer_or_templates.sparsity
else:
# assert provided sparsity is a subset of waveform sparsity
combined_mask = np.logical_or(sorting_analyzer_or_templates.sparsity.mask, sparsity.mask)
assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0), (
"The provided 'sparsity' needs to include only the sparse channels "
"used to extract waveforms (for example, by using a smaller 'radius_um')."
)
extra_sparsity = True
else:
if sparsity is None:
# in this case, we construct a dense sparsity
unit_id_to_channel_ids = {
u: sorting_analyzer_or_templates.channel_ids for u in sorting_analyzer_or_templates.unit_ids
}
sparsity = ChannelSparsity.from_unit_id_to_channel_ids(
unit_id_to_channel_ids=unit_id_to_channel_ids,
unit_ids=sorting_analyzer_or_templates.unit_ids,
channel_ids=sorting_analyzer_or_templates.channel_ids,
)
else:
assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!"
channel_locations = sorting_analyzer_or_templates.get_channel_locations()
extra_sparsity = None
# handle sparsity
sparsity_mismatch_warning = (
"The provided 'sparsity' includes additional channels not in the analyzer sparsity. "
"These extra channels will be plotted as flat lines."
)
analyzer_sparsity = sorting_analyzer_or_templates.sparsity
if channel_ids is not None:
assert sparsity is None, "If 'channel_ids' is provided, 'sparsity' should be None!"
channel_mask = np.tile(
np.isin(sorting_analyzer_or_templates.channel_ids, channel_ids),
(len(sorting_analyzer_or_templates.unit_ids), 1),
)
extra_sparsity = ChannelSparsity(
mask=channel_mask,
channel_ids=sorting_analyzer_or_templates.channel_ids,
unit_ids=sorting_analyzer_or_templates.unit_ids,
)
elif sparsity is not None:
extra_sparsity = sparsity

if channel_ids is None:
channel_ids = sorting_analyzer_or_templates.channel_ids

# assert provided sparsity is a subset of waveform sparsity
if extra_sparsity is not None and analyzer_sparsity is not None:
combined_mask = np.logical_or(analyzer_sparsity.mask, extra_sparsity.mask)
if not np.all(np.sum(combined_mask, 1) - np.sum(analyzer_sparsity.mask, 1) == 0):
warn(sparsity_mismatch_warning)

final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity
if final_sparsity is None:
final_sparsity = ChannelSparsity(
mask=np.ones(
(len(sorting_analyzer_or_templates.unit_ids), len(sorting_analyzer_or_templates.channel_ids)),
dtype=bool,
),
unit_ids=sorting_analyzer_or_templates.unit_ids,
channel_ids=sorting_analyzer_or_templates.channel_ids,
)

# get templates
if isinstance(sorting_analyzer_or_templates, Templates):
@@ -174,42 +186,23 @@ def __init__(
templates_percentile_shading = None
templates_shading = self._get_template_shadings(unit_ids, templates_percentile_shading)

wfs_by_ids = {}
if plot_waveforms:
# this must be a sorting_analyzer
wf_ext = sorting_analyzer_or_templates.get_extension("waveforms")
if wf_ext is None:
raise ValueError("plot_waveforms() needs the extension 'waveforms'")
for unit_id in unit_ids:
unit_index = list(sorting_analyzer_or_templates.unit_ids).index(unit_id)
if not extra_sparsity:
if sorting_analyzer_or_templates.is_sparse():
# wfs = we.get_waveforms(unit_id)
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
else:
# wfs = we.get_waveforms(unit_id, sparsity=sparsity)
wfs = wf_ext.get_waveforms_one_unit(unit_id)
wfs = wfs[:, :, sparsity.mask[unit_index]]
else:
# in this case we have to slice the waveform sparsity based on the extra sparsity
# first get the sparse waveforms
# wfs = we.get_waveforms(unit_id)
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
# find additional slice to apply to sparse waveforms
(wfs_sparse_indices,) = np.nonzero(sorting_analyzer_or_templates.sparsity.mask[unit_index])
(extra_sparse_indices,) = np.nonzero(sparsity.mask[unit_index])
(extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices))
# apply extra sparsity
wfs = wfs[:, :, extra_slice]
wfs_by_ids[unit_id] = wfs
wfs_by_ids = self._get_wfs_by_ids(sorting_analyzer_or_templates, unit_ids, extra_sparsity=extra_sparsity)
else:
wfs_by_ids = None

plot_data = dict(
sorting_analyzer_or_templates=sorting_analyzer_or_templates,
sampling_frequency=sorting_analyzer_or_templates.sampling_frequency,
nbefore=nbefore,
unit_ids=unit_ids,
channel_ids=channel_ids,
sparsity=sparsity,
final_sparsity=final_sparsity,
extra_sparsity=extra_sparsity,
unit_colors=unit_colors,
channel_locations=channel_locations,
scale=scale,
@@ -270,7 +263,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
ax = self.axes.flatten()[i]
color = dp.unit_colors[unit_id]

chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id]
chan_inds = dp.final_sparsity.unit_id_to_channel_indices[unit_id]
xvectors_flat = xvectors[:, chan_inds].T.flatten()

# plot waveforms
@@ -502,6 +495,32 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
if backend_kwargs["display"]:
display(self.widget)

def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, extra_sparsity):
wfs_by_ids = {}
wf_ext = sorting_analyzer.get_extension("waveforms")
for unit_id in unit_ids:
unit_index = list(sorting_analyzer.unit_ids).index(unit_id)
if extra_sparsity is None:
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
else:
# in this case we have to construct waveforms based on the extra sparsity and add the
# sparse waveforms on the valid channels
if sorting_analyzer.is_sparse():
original_mask = sorting_analyzer.sparsity.mask[unit_index]
else:
original_mask = np.ones(len(sorting_analyzer.channel_ids), dtype=bool)
wfs_orig = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
wfs = np.zeros(
(wfs_orig.shape[0], wfs_orig.shape[1], extra_sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype
)
# fill in the existing waveforms channels
valid_wfs_indices = extra_sparsity.mask[unit_index][original_mask]
valid_extra_indices = original_mask[extra_sparsity.mask[unit_index]]
wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices]

wfs_by_ids[unit_id] = wfs
return wfs_by_ids

def _get_template_shadings(self, unit_ids, templates_percentile_shading):
templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average")

@@ -538,6 +557,8 @@ def _update_plot(self, change):
hide_axis = self.hide_axis_button.value
do_shading = self.template_shading_button.value

data_plot = self.next_data_plot

if self.sorting_analyzer is not None:
templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average")
templates_shadings = self._get_template_shadings(unit_ids, data_plot["templates_percentile_shading"])
@@ -549,7 +570,6 @@ def _update_plot(self, change):
channel_locations = self.templates.get_channel_locations()

# matplotlib next_data_plot dict update at each call
data_plot = self.next_data_plot
data_plot["unit_ids"] = unit_ids
data_plot["templates"] = templates
data_plot["templates_shading"] = templates_shadings
@@ -564,10 +584,10 @@ def _update_plot(self, change):
data_plot["scalebar"] = self.scalebar.value

if data_plot["plot_waveforms"]:
wf_ext = self.sorting_analyzer.get_extension("waveforms")
data_plot["wfs_by_ids"] = {
unit_id: wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) for unit_id in unit_ids
}
wfs_by_ids = self._get_wfs_by_ids(
self.sorting_analyzer, unit_ids, extra_sparsity=data_plot["extra_sparsity"]
)
data_plot["wfs_by_ids"] = wfs_by_ids

# TODO option for plot_legend
backend_kwargs = {}
@@ -611,7 +631,7 @@ def _plot_probe(self, ax, channel_locations, unit_ids):

# TODO this could be done with probeinterface plotting tools!!
for unit in unit_ids:
channel_inds = self.data_plot["sparsity"].unit_id_to_channel_indices[unit]
channel_inds = self.data_plot["final_sparsity"].unit_id_to_channel_indices[unit]
ax.plot(
channel_locations[channel_inds, 0],
channel_locations[channel_inds, 1],
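For reference, the index gymnastics in _get_wfs_by_ids can be checked in isolation with plain numpy. In this toy sketch (made-up masks and shapes), channels requested by the extra sparsity but absent from the stored waveforms stay at zero, which is what ends up drawn as a flat line:

import numpy as np

original_mask = np.array([True, True, True, False, False])  # channels stored with the waveforms
extra_mask = np.array([True, False, True, True, False])      # channels requested for plotting

n_spikes, n_samples = 10, 90
wfs_orig = np.random.randn(n_spikes, n_samples, original_mask.sum())

wfs = np.zeros((n_spikes, n_samples, extra_mask.sum()), dtype=wfs_orig.dtype)
valid_wfs_indices = extra_mask[original_mask]    # stored channels that were also requested
valid_extra_indices = original_mask[extra_mask]  # requested channels that are actually stored
wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices]

# the third requested channel (index 3) is not stored, so it stays flat
assert wfs.shape[2] == 3 and not wfs[:, :, 2].any()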
1 change: 1 addition & 0 deletions src/spikeinterface/widgets/utils.py
@@ -151,6 +151,7 @@ def array_to_image(
output_image : 3D numpy array
"""
import matplotlib.pyplot as plt

from scipy.ndimage import zoom

Expand Down

0 comments on commit ab59a93

Please sign in to comment.