Skip to content

Commit

Permalink
Fix plot_traces with ipywidgets when channel_ids is not None
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Oct 6, 2023
1 parent cdc1ccb commit 3448e1e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
10 changes: 6 additions & 4 deletions src/spikeinterface/widgets/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,10 @@ def __init__(

# colors is a nested dict by layer and channels
# lets first create black for all channels and layer
# all color are generated for ipywidgets
colors = {}
for k in layer_keys:
colors[k] = {chan_id: "k" for chan_id in channel_ids}
colors[k] = {chan_id: "k" for chan_id in rec0.channel_ids}

if color_groups:
channel_groups = rec0.get_channel_groups(channel_ids=channel_ids)
Expand All @@ -149,7 +150,7 @@ def __init__(
group_colors = get_some_colors(groups, color_engine="auto")

channel_colors = {}
for i, chan_id in enumerate(channel_ids):
for i, chan_id in enumerate(rec0.channel_ids):
group = channel_groups[i]
channel_colors[chan_id] = group_colors[group]

Expand All @@ -159,12 +160,12 @@ def __init__(
elif color is not None:
# old behavior one color for all channel
# if multi layer then black for all
colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids}
colors[layer_keys[0]] = {chan_id: color for chan_id in rec0.channel_ids}
elif color is None and len(recordings) > 1:
# several layer
layer_colors = get_some_colors(layer_keys)
for k in layer_keys:
colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids}
colors[k] = {chan_id: layer_colors[k] for chan_id in rec0.channel_ids}
else:
# color is None unique layer : all channels black
pass
Expand Down Expand Up @@ -336,6 +337,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
)
self.scaler = ScaleWidget()
self.channel_selector = ChannelSelector(self.rec0.channel_ids)
self.channel_selector.value = data_plot["channel_ids"]

left_sidebar = W.VBox(
children=[
Expand Down
16 changes: 14 additions & 2 deletions src/spikeinterface/widgets/utils_ipywidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,7 @@ def __init__(self, channel_ids, **kwargs):
self.slider.observe(self.on_slider_changed, names=["value"], type="change")
self.selector.observe(self.on_selector_changed, names=["value"], type="change")

# TODO external value change
# self.observe(self.value_changed, names=['value'], type="change")
self.observe(self.value_changed, names=['value'], type="change")

def on_slider_changed(self, change=None):
i0, i1 = self.slider.value
Expand All @@ -259,6 +258,19 @@ def on_selector_changed(self, change=None):
self.slider.observe(self.on_slider_changed, names=["value"], type="change")

self.value = channel_ids

def value_changed(self, change=None):
self.selector.unobserve(self.on_selector_changed, names=["value"], type="change")
self.selector.value = change["new"]
self.selector.observe(self.on_selector_changed, names=["value"], type="change")

channel_ids = self.selector.value
self.slider.unobserve(self.on_slider_changed, names=["value"], type="change")
i0 = self.channel_ids.index(channel_ids[0])
i1 = self.channel_ids.index(channel_ids[-1]) + 1
self.slider.value = (i0, i1)
self.slider.observe(self.on_slider_changed, names=["value"], type="change")



class ScaleWidget(W.VBox):
Expand Down

0 comments on commit 3448e1e

Please sign in to comment.