Skip to content

Commit

Permalink
Fix order_channel_by_depth in ipywidgets
Browse files Browse the repository at this point in the history
Fix order_channel_by_depth when channel_ids is given.
  • Loading branch information
samuelgarcia committed Oct 6, 2023
1 parent 3448e1e commit e51bb75
Showing 1 changed file with 32 additions and 26 deletions.
58 changes: 32 additions & 26 deletions src/spikeinterface/widgets/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,26 +88,33 @@ def __init__(
else:
raise ValueError("plot_traces recording must be recording or dict or list")

if "location" in rec0.get_property_keys():
channel_locations = rec0.get_channel_locations()
else:
channel_locations = None

if order_channel_by_depth and channel_locations is not None:
from ..preprocessing import depth_order
rec0 = depth_order(rec0)
recordings = {k: depth_order(rec) for k, rec in recordings.items()}

if channel_ids is not None:
# ensure that channel_ids are in the good order
channel_ids_ = list(rec0.channel_ids)
order = np.argsort([channel_ids_.index(c) for c in channel_ids])
channel_ids = list(np.array(channel_ids)[order])

if channel_ids is None:
channel_ids = rec0.channel_ids


layer_keys = list(recordings.keys())

if segment_index is None:
if rec0.get_num_segments() != 1:
raise ValueError("You must provide segment_index=...")
segment_index = 0

if channel_ids is None:
channel_ids = rec0.channel_ids

if "location" in rec0.get_property_keys():
channel_locations = rec0.get_channel_locations()
else:
channel_locations = None

if order_channel_by_depth:
if channel_locations is not None:
order, _ = order_channels_by_depth(rec0, channel_ids)
else:
order = None

fs = rec0.get_sampling_frequency()
if time_range is None:
Expand All @@ -124,7 +131,7 @@ def __init__(
cmap = cmap

times, list_traces, frame_range, channel_ids = _get_trace_list(
recordings, channel_ids, time_range, segment_index, order, return_scaled
recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled
)

# stat for auto scaling done on the first layer
Expand Down Expand Up @@ -202,7 +209,6 @@ def __init__(
show_channel_ids=show_channel_ids,
add_legend=add_legend,
order_channel_by_depth=order_channel_by_depth,
order=order,
tile_size=tile_size,
num_timepoints_per_row=int(seconds_per_row * fs),
return_scaled=return_scaled,
Expand Down Expand Up @@ -337,7 +343,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
)
self.scaler = ScaleWidget()
self.channel_selector = ChannelSelector(self.rec0.channel_ids)
self.channel_selector.value = data_plot["channel_ids"]
self.channel_selector.value = list(data_plot["channel_ids"])

left_sidebar = W.VBox(
children=[
Expand Down Expand Up @@ -400,17 +406,17 @@ def _mode_changed(self, change=None):
def _retrieve_traces(self, change=None):
channel_ids = np.array(self.channel_selector.value)

if self.data_plot["order_channel_by_depth"]:
order, _ = order_channels_by_depth(self.rec0, channel_ids)
else:
order = None
# if self.data_plot["order_channel_by_depth"]:
# order, _ = order_channels_by_depth(self.rec0, channel_ids)
# else:
# order = None

start_frame, end_frame, segment_index = self.time_slider.value
time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency

self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()}
times, list_traces, frame_range, channel_ids = _get_trace_list(
self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled
self._selected_recordings, channel_ids, time_range, segment_index, return_scaled=self.return_scaled
)

self._channel_ids = channel_ids
Expand Down Expand Up @@ -525,7 +531,7 @@ def plot_ephyviewer(self, data_plot, **backend_kwargs):
app.exec()


def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False):
def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_scaled=False):
# function also used in ipywidgets plotter
k0 = list(recordings.keys())[0]
rec0 = recordings[k0]
Expand All @@ -552,11 +558,11 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=No
return_scaled=return_scaled,
)

if order is not None:
traces = traces[:, order]
# if order is not None:
# traces = traces[:, order]
list_traces.append(traces)

if order is not None:
channel_ids = np.array(channel_ids)[order]
# if order is not None:
# channel_ids = np.array(channel_ids)[order]

return times, list_traces, frame_range, channel_ids

0 comments on commit e51bb75

Please sign in to comment.