Skip to content

Commit

Permalink
Merge pull request #2149 from zm711/plot_traces
Browse files Browse the repository at this point in the history
Add assert error in the case of unichannel data for sortingview backend + minor clarifications
  • Loading branch information
alejoe91 authored Nov 2, 2023
2 parents 9df709d + fc2294a commit 5b2d5bf
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/spikeinterface/widgets/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class TracesWidget(BaseWidget):
List with start time and end time
mode: "line" | "map" | "auto", default: "auto"
Three possible modes
* "line": classical for low channel count
* "map": for high channel count use color heat map
* "auto": auto switch depending on the channel count ("line" if less than 64 channels, "map" otherwise)
Expand All @@ -50,7 +49,7 @@ class TracesWidget(BaseWidget):
seconds_per_row: float, default: 0.2
For "map" mode and sortingview backend, seconds to render in each row
add_legend : bool, default: True
If True adds legend to figures, default: True
If True adds legend to figures
"""

def __init__(
Expand Down Expand Up @@ -85,7 +84,10 @@ def __init__(
recordings = {f"rec{i}": rec for i, rec in enumerate(recording)}
rec0 = recordings[0]
else:
raise ValueError("plot_traces recording must be recording or dict or list")
raise ValueError(
"plot_traces 'recording' must be recording or dict or list, recording "
f"is currently of type {type(recording)}"
)

if rec0.has_channel_location():
channel_locations = rec0.get_channel_locations()
Expand All @@ -111,15 +113,15 @@ def __init__(

if segment_index is None:
if rec0.get_num_segments() != 1:
raise ValueError("You must provide segment_index=...")
raise ValueError('You must provide "segment_index" for multisegment recordings.')
segment_index = 0

fs = rec0.get_sampling_frequency()
if time_range is None:
time_range = (0, 1.0)
time_range = np.array(time_range)

assert mode in ("auto", "line", "map"), "Mode must be in auto/line/map"
assert mode in ("auto", "line", "map"), 'Mode must be one of "auto","line", "map"'
if mode == "auto":
if len(channel_ids) <= 64:
mode = "line"
Expand Down Expand Up @@ -181,7 +183,9 @@ def __init__(
if isinstance(clim, tuple):
clims = {layer_key: clim for layer_key in layer_keys}
elif isinstance(clim, dict):
assert all(layer_key in clim for layer_key in layer_keys), ""
assert all(
layer_key in clim for layer_key in layer_keys
), f"all recordings must be a key in `clim` if `clim` is a dict. Provide keys {layer_keys} in clim"
clims = clim
else:
raise TypeError(f"'clim' can be None, tuple, or dict! Unsupported type {type(clim)}")
Expand Down Expand Up @@ -257,7 +261,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
ax.legend(loc="upper right")

elif dp.mode == "map":
assert len(dp.list_traces) == 1, 'plot_traces with mode="map" do not support multi recording'
assert len(dp.list_traces) == 1, 'plot_traces with mode="map" does not support multi-recording'
assert len(dp.clims) == 1
clim = list(dp.clims.values())[0]
extent = (dp.time_range[0], dp.time_range[1], min_y, max_y)
Expand Down Expand Up @@ -473,11 +477,11 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
try:
import pyvips
except ImportError:
raise ImportError("To use the timeseries in sorting view you need the pyvips package.")
raise ImportError("To use `plot_traces()` in sortingview you need the pyvips package.")

dp = to_attr(data_plot)

assert dp.mode == "map", 'sortingview plot_traces is only mode="map"'
assert dp.mode == "map", 'sortingview `plot_traces` can only have mode="map"'

if not dp.order_channel_by_depth:
warnings.warn(
Expand All @@ -486,6 +490,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

tiled_layers = []
for layer_key, traces in zip(dp.layer_keys, dp.list_traces):
assert traces.shape[1] != 1, 'mode="map" only works with multichannel data'
img = array_to_image(
traces,
clim=dp.clims[layer_key],
Expand All @@ -499,7 +504,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

self.view = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers)

# timeseries currently doesn't display on the jupyter backend
# traces currently doesn't display on the jupyter backend
backend_kwargs["display"] = False

self.url = handle_display_and_url(self, self.view, **backend_kwargs)
Expand Down

0 comments on commit 5b2d5bf

Please sign in to comment.