From e51bb75f226c7c2be97c4a6ceeae460a7c610efe Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 09:25:35 +0200 Subject: [PATCH] Fix order_channel_by_depth in ipywidgets Fix order_channel_by_depth when channel_ids is given. --- src/spikeinterface/widgets/traces.py | 58 +++++++++++++++------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 2783b6a369..802f90c62a 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,6 +88,26 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") + if "location" in rec0.get_property_keys(): + channel_locations = rec0.get_channel_locations() + else: + channel_locations = None + + if order_channel_by_depth and channel_locations is not None: + from ..preprocessing import depth_order + rec0 = depth_order(rec0) + recordings = {k: depth_order(rec) for k, rec in recordings.items()} + + if channel_ids is not None: + # ensure that channel_ids are in the good order + channel_ids_ = list(rec0.channel_ids) + order = np.argsort([channel_ids_.index(c) for c in channel_ids]) + channel_ids = list(np.array(channel_ids)[order]) + + if channel_ids is None: + channel_ids = rec0.channel_ids + + layer_keys = list(recordings.keys()) if segment_index is None: @@ -95,19 +115,6 @@ def __init__( raise ValueError("You must provide segment_index=...") segment_index = 0 - if channel_ids is None: - channel_ids = rec0.channel_ids - - if "location" in rec0.get_property_keys(): - channel_locations = rec0.get_channel_locations() - else: - channel_locations = None - - if order_channel_by_depth: - if channel_locations is not None: - order, _ = order_channels_by_depth(rec0, channel_ids) - else: - order = None fs = rec0.get_sampling_frequency() if time_range is None: @@ -124,7 +131,7 @@ def __init__( cmap = cmap times, list_traces, frame_range, channel_ids = _get_trace_list( - recordings, channel_ids, time_range, segment_index, order, return_scaled + recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled ) # stat for auto scaling done on the first layer @@ -202,7 +209,6 @@ def __init__( show_channel_ids=show_channel_ids, add_legend=add_legend, order_channel_by_depth=order_channel_by_depth, - order=order, tile_size=tile_size, num_timepoints_per_row=int(seconds_per_row * fs), return_scaled=return_scaled, @@ -337,7 +343,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) - self.channel_selector.value = data_plot["channel_ids"] + self.channel_selector.value = list(data_plot["channel_ids"]) left_sidebar = W.VBox( children=[ @@ -400,17 +406,17 @@ def _mode_changed(self, change=None): def _retrieve_traces(self, change=None): channel_ids = np.array(self.channel_selector.value) - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None + # if self.data_plot["order_channel_by_depth"]: + # order, _ = order_channels_by_depth(self.rec0, channel_ids) + # else: + # order = None start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times, list_traces, frame_range, channel_ids = _get_trace_list( - self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled + self._selected_recordings, channel_ids, time_range, segment_index, return_scaled=self.return_scaled ) self._channel_ids = channel_ids @@ -525,7 +531,7 @@ def plot_ephyviewer(self, data_plot, **backend_kwargs): app.exec() -def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): +def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_scaled=False): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] rec0 = recordings[k0] @@ -552,11 +558,11 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=No return_scaled=return_scaled, ) - if order is not None: - traces = traces[:, order] + # if order is not None: + # traces = traces[:, order] list_traces.append(traces) - if order is not None: - channel_ids = np.array(channel_ids)[order] + # if order is not None: + # channel_ids = np.array(channel_ids)[order] return times, list_traces, frame_range, channel_ids