From 3fcfd19528e9153897ce4aaa8f3b296069d95c8a Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 31 Oct 2023 15:18:09 -0400 Subject: [PATCH 1/2] typos in asserts + add assert --- src/spikeinterface/widgets/traces.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 63fe4e8d8f..250809f61c 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -26,7 +26,6 @@ class TracesWidget(BaseWidget): List with start time and end time mode: "line" | "map" | "auto", default: "auto" Three possible modes - * "line": classical for low channel count * "map": for high channel count use color heat map * "auto": auto switch depending on the channel count ("line" if less than 64 channels, "map" otherwise) @@ -50,7 +49,7 @@ class TracesWidget(BaseWidget): seconds_per_row: float, default: 0.2 For "map" mode and sortingview backend, seconds to render in each row add_legend : bool, default: True - If True adds legend to figures, default: True + If True adds legend to figures """ def __init__( @@ -85,7 +84,10 @@ def __init__( recordings = {f"rec{i}": rec for i, rec in enumerate(recording)} rec0 = recordings[0] else: - raise ValueError("plot_traces recording must be recording or dict or list") + raise ValueError( + "plot_traces 'recording' must be recording or dict or list, recording type " + f"is currently of type {type(recording)}" + ) if rec0.has_channel_location(): channel_locations = rec0.get_channel_locations() @@ -111,7 +113,7 @@ def __init__( if segment_index is None: if rec0.get_num_segments() != 1: - raise ValueError("You must provide segment_index=...") + raise ValueError('You must provide "segment_index" for multisegment recordings.') segment_index = 0 fs = rec0.get_sampling_frequency() @@ -119,7 +121,7 @@ def __init__( time_range = (0, 1.0) time_range = np.array(time_range) - assert mode in ("auto", "line", "map"), "Mode must be in auto/line/map" + assert mode in ("auto", "line", "map"), 'Mode must be one of "auto","line", "map"' if mode == "auto": if len(channel_ids) <= 64: mode = "line" @@ -257,7 +259,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.legend(loc="upper right") elif dp.mode == "map": - assert len(dp.list_traces) == 1, 'plot_traces with mode="map" do not support multi recording' + assert len(dp.list_traces) == 1, 'plot_traces with mode="map" does not support multi-recording' assert len(dp.clims) == 1 clim = list(dp.clims.values())[0] extent = (dp.time_range[0], dp.time_range[1], min_y, max_y) @@ -473,11 +475,11 @@ def plot_sortingview(self, data_plot, **backend_kwargs): try: import pyvips except ImportError: - raise ImportError("To use the timeseries in sorting view you need the pyvips package.") + raise ImportError("To use plot_traces in sortingview you need the pyvips package.") dp = to_attr(data_plot) - assert dp.mode == "map", 'sortingview plot_traces is only mode="map"' + assert dp.mode == "map", 'sortingview `plot_traces` can only have mode="map"' if not dp.order_channel_by_depth: warnings.warn( @@ -486,6 +488,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): tiled_layers = [] for layer_key, traces in zip(dp.layer_keys, dp.list_traces): + assert ( + traces.shape[1] != 1 + ), f'mode="map" only works with multichannel data, you currently have {traces.shape[1]} channels' img = array_to_image( traces, clim=dp.clims[layer_key], @@ -499,7 +504,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.view = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) - # timeseries currently doesn't display on the jupyter backend + # traces currently doesn't display on the jupyter backend backend_kwargs["display"] = False self.url = handle_display_and_url(self, self.view, **backend_kwargs) From fc2294a1d7c03c71d52f57b7376ebc768c149bd3 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 31 Oct 2023 16:42:38 -0400 Subject: [PATCH 2/2] add another assert --- src/spikeinterface/widgets/traces.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 250809f61c..45ae9c91aa 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -85,7 +85,7 @@ def __init__( rec0 = recordings[0] else: raise ValueError( - "plot_traces 'recording' must be recording or dict or list, recording type " + "plot_traces 'recording' must be recording or dict or list, recording " f"is currently of type {type(recording)}" ) @@ -183,7 +183,9 @@ def __init__( if isinstance(clim, tuple): clims = {layer_key: clim for layer_key in layer_keys} elif isinstance(clim, dict): - assert all(layer_key in clim for layer_key in layer_keys), "" + assert all( + layer_key in clim for layer_key in layer_keys + ), f"all recordings must be a key in `clim` if `clim` is a dict. Provide keys {layer_keys} in clim" clims = clim else: raise TypeError(f"'clim' can be None, tuple, or dict! Unsupported type {type(clim)}") @@ -475,7 +477,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): try: import pyvips except ImportError: - raise ImportError("To use plot_traces in sortingview you need the pyvips package.") + raise ImportError("To use `plot_traces()` in sortingview you need the pyvips package.") dp = to_attr(data_plot) @@ -488,9 +490,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): tiled_layers = [] for layer_key, traces in zip(dp.layer_keys, dp.list_traces): - assert ( - traces.shape[1] != 1 - ), f'mode="map" only works with multichannel data, you currently have {traces.shape[1]} channels' + assert traces.shape[1] != 1, 'mode="map" only works with multichannel data' img = array_to_image( traces, clim=dp.clims[layer_key],