Skip to content

Commit

Permalink
Simplify sparsity handling in plot waveforms/templates and fix plot_t…
Browse files Browse the repository at this point in the history
…races sortingview
  • Loading branch information
alejoe91 committed Jul 15, 2024
1 parent 63a31c0 commit 600f25e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 44 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/widgets/unit_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview"

# ensure serializable for sortingview
unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids
unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices
unit_id_to_channel_ids = dp.final_sparsity.unit_id_to_channel_ids
unit_id_to_channel_indices = dp.final_sparsity.unit_id_to_channel_indices

unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids)

Expand Down
77 changes: 35 additions & 42 deletions src/spikeinterface/widgets/unit_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,51 +123,47 @@ def __init__(
unit_colors = get_unit_colors(sorting_analyzer_or_templates)

channel_locations = sorting_analyzer_or_templates.get_channel_locations()
extra_sparsity = False
extra_sparsity = None
# handle sparsity
sparsity_mismatch_warning = (
"The provided 'sparsity' includes additional channels not in the analyzer sparsity. "
"These extra channels will be plotted as flat lines."
)
analyzer_sparsity = sorting_analyzer_or_templates.sparsity
if channel_ids is not None:
assert sparsity is None, "If 'channel_ids' is provided, 'sparsity' should be None!"
channel_mask = np.tile(
np.isin(sorting_analyzer_or_templates.channel_ids, channel_ids),
(len(sorting_analyzer_or_templates.unit_ids), 1),
)
sparsity = ChannelSparsity(
extra_sparsity = ChannelSparsity(
mask=channel_mask,
channel_ids=sorting_analyzer_or_templates.channel_ids,
unit_ids=sorting_analyzer_or_templates.unit_ids,
)
extra_sparsity = True
elif analyzer_sparsity is not None:
if sparsity is None:
sparsity = analyzer_sparsity
else:
extra_sparsity = True
else:
if sparsity is None:
unit_id_to_channel_ids = {
u: sorting_analyzer_or_templates.channel_ids for u in sorting_analyzer_or_templates.unit_ids
}
sparsity = ChannelSparsity.from_unit_id_to_channel_ids(
unit_id_to_channel_ids=unit_id_to_channel_ids,
unit_ids=sorting_analyzer_or_templates.unit_ids,
channel_ids=sorting_analyzer_or_templates.channel_ids,
)
else:
assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!"
elif sparsity is not None:
extra_sparsity = sparsity

if channel_ids is None:
channel_ids = sorting_analyzer_or_templates.channel_ids

# assert provided sparsity is a subset of waveform sparsity
if extra_sparsity:
combined_mask = np.logical_or(analyzer_sparsity.mask, sparsity.mask)
if not np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0):
if extra_sparsity is not None and analyzer_sparsity is not None:
combined_mask = np.logical_or(analyzer_sparsity.mask, extra_sparsity.mask)
if not np.all(np.sum(combined_mask, 1) - np.sum(analyzer_sparsity.mask, 1) == 0):
warn(sparsity_mismatch_warning)

final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity
if final_sparsity is None:
final_sparsity = ChannelSparsity(
mask=np.ones(
(len(sorting_analyzer_or_templates.unit_ids), len(sorting_analyzer_or_templates.channel_ids)),
dtype=bool,
),
unit_ids=sorting_analyzer_or_templates.unit_ids,
channel_ids=sorting_analyzer_or_templates.channel_ids,
)

# get templates
if isinstance(sorting_analyzer_or_templates, Templates):
templates = sorting_analyzer_or_templates.templates_array
Expand Down Expand Up @@ -195,9 +191,7 @@ def __init__(
wf_ext = sorting_analyzer_or_templates.get_extension("waveforms")
if wf_ext is None:
raise ValueError("plot_waveforms() needs the extension 'waveforms'")
wfs_by_ids = self._get_wfs_by_ids(
sorting_analyzer_or_templates, unit_ids, sparsity, extra_sparsity=extra_sparsity
)
wfs_by_ids = self._get_wfs_by_ids(sorting_analyzer_or_templates, unit_ids, extra_sparsity=extra_sparsity)
else:
wfs_by_ids = None

Expand All @@ -207,7 +201,8 @@ def __init__(
nbefore=nbefore,
unit_ids=unit_ids,
channel_ids=channel_ids,
sparsity=sparsity,
final_sparsity=final_sparsity,
extra_sparsity=extra_sparsity,
unit_colors=unit_colors,
channel_locations=channel_locations,
scale=scale,
Expand All @@ -234,7 +229,6 @@ def __init__(
alpha_templates=alpha_templates,
hide_unit_selector=hide_unit_selector,
plot_legend=plot_legend,
extra_sparsity=extra_sparsity,
)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

Expand Down Expand Up @@ -269,7 +263,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
ax = self.axes.flatten()[i]
color = dp.unit_colors[unit_id]

chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id]
chan_inds = dp.final_sparsity.unit_id_to_channel_indices[unit_id]
xvectors_flat = xvectors[:, chan_inds].T.flatten()

# plot waveforms
Expand Down Expand Up @@ -501,28 +495,27 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
if backend_kwargs["display"]:
display(self.widget)

def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, sparsity, extra_sparsity=False):
def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, extra_sparsity):
wfs_by_ids = {}
wf_ext = sorting_analyzer.get_extension("waveforms")
for unit_id in unit_ids:
unit_index = list(sorting_analyzer.unit_ids).index(unit_id)
if not extra_sparsity:
# get waveforms with default sparsity
if sorting_analyzer.is_sparse():
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
else:
wfs = wf_ext.get_waveforms_one_unit(unit_id)
wfs = wfs[:, :, sparsity.mask[unit_index]]
if extra_sparsity is None:
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
else:
# in this case we have to construct waveforms based on the extra sparsity and add the
# sparse waveforms on the valid channels
if sorting_analyzer.is_sparse():
original_mask = sorting_analyzer.sparsity.mask[unit_index]
else:
original_mask = np.ones(len(sorting_analyzer.channel_ids), dtype=bool)
wfs_orig = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
wfs = np.zeros(
(wfs_orig.shape[0], wfs_orig.shape[1], sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype
(wfs_orig.shape[0], wfs_orig.shape[1], extra_sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype
)
# fill in the existing waveforms channels
valid_wfs_indices = sparsity.mask[unit_index][sorting_analyzer.sparsity.mask[unit_index]]
valid_extra_indices = sorting_analyzer.sparsity.mask[unit_index][sparsity.mask[unit_index]]
valid_wfs_indices = extra_sparsity.mask[unit_index][original_mask]
valid_extra_indices = original_mask[extra_sparsity.mask[unit_index]]
wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices]

wfs_by_ids[unit_id] = wfs
Expand Down Expand Up @@ -592,7 +585,7 @@ def _update_plot(self, change):

if data_plot["plot_waveforms"]:
wfs_by_ids = self._get_wfs_by_ids(
self.sorting_analyzer, unit_ids, data_plot["sparsity"], extra_sparsity=data_plot["extra_sparsity"]
self.sorting_analyzer, unit_ids, extra_sparsity=data_plot["extra_sparsity"]
)
data_plot["wfs_by_ids"] = wfs_by_ids

Expand Down Expand Up @@ -638,7 +631,7 @@ def _plot_probe(self, ax, channel_locations, unit_ids):

# TODO this could be done with probeinterface plotting plotting tools!!
for unit in unit_ids:
channel_inds = self.data_plot["sparsity"].unit_id_to_channel_indices[unit]
channel_inds = self.data_plot["final_sparsity"].unit_id_to_channel_indices[unit]
ax.plot(
channel_locations[channel_inds, 0],
channel_locations[channel_inds, 1],
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/widgets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def array_to_image(
output_image : 3D numpy array
"""
import matplotlib.pyplot as plt

from scipy.ndimage import zoom

Expand Down

0 comments on commit 600f25e

Please sign in to comment.