From e964731b33401db1757ce813d2078c00a36dcf34 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Sep 2023 16:36:31 +0200 Subject: [PATCH 1/8] Start refactor ipywidgets plot_traces --- src/spikeinterface/widgets/traces.py | 29 +- .../widgets/utils_ipywidgets.py | 251 ++++++++++++++++-- 2 files changed, 254 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 7bb2126744..c6e36387f8 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -276,11 +276,16 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display + import ipywidgets.widgets as W from .utils_ipywidgets import ( check_ipywidget_backend, make_timeseries_controller, make_channel_controller, make_scale_controller, + + TimeSlider, + ScaleWidget, + ) check_ipywidget_backend() @@ -308,6 +313,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): t_start = 0.0 t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() + + ts_widget, ts_controller = make_timeseries_controller( t_start, t_stop, @@ -319,6 +326,22 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm, ) + # some widgets + self.time_slider = TimeSlider( + durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], + sampling_frequency=rec0.sampling_frequency, + ) + self.layer_selector = W.Dropdown(description="layer", options=data_plot["layer_keys"], + layout=W.Layout(width="5cm"),) + self.mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=data_plot["mode"], + layout=W.Layout(width="5cm"),) + self.scaler = ScaleWidget() + left_sidebar = W.VBox( + children=[self.layer_selector, self.mode_selector, self.scaler], + layout=W.Layout(width="5cm"), + ) + + ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) @@ -346,8 +369,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.widget = widgets.AppLayout( center=self.figure.canvas, - footer=ts_widget, - left_sidebar=scale_widget, + # footer=ts_widget, + footer=self.time_slider, + # left_sidebar=scale_widget, + left_sidebar = left_sidebar, right_sidebar=ch_widget, pane_heights=[0, 6, 1], pane_widths=ratios, diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index a7c571d1f0..674a2d2cc7 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -1,4 +1,6 @@ -import ipywidgets.widgets as widgets +import ipywidgets.widgets as W +import traitlets + import numpy as np @@ -10,20 +12,20 @@ def check_ipywidget_backend(): def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): - time_slider = widgets.FloatSlider( + time_slider = W.FloatSlider( orientation="horizontal", description="time:", value=time_range[0], min=t_start, max=t_stop, continuous_update=False, - layout=widgets.Layout(width=f"{width_cm}cm"), + layout=W.Layout(width=f"{width_cm}cm"), ) - layer_selector = widgets.Dropdown(description="layer", options=layer_keys) - segment_selector = widgets.Dropdown(description="segment", options=list(range(num_segments))) - window_sizer = widgets.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") - mode_selector = widgets.Dropdown(options=["line", "map"], description="mode", value=mode) - all_layers = widgets.Checkbox(description="plot all layers", value=all_layers) + layer_selector = W.Dropdown(description="layer", options=layer_keys) + segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) + window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") + mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) + all_layers = W.Checkbox(description="plot all layers", value=all_layers) controller = { "layer_key": layer_selector, @@ -33,32 +35,32 @@ def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_r "mode": mode_selector, "all_layers": all_layers, } - widget = widgets.VBox( - [time_slider, widgets.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] + widget = W.VBox( + [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] ) return widget, controller def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): - unit_label = widgets.Label(value="units:") + unit_label = W.Label(value="units:") - unit_selector = widgets.SelectMultiple( + unit_selector = W.SelectMultiple( options=all_unit_ids, value=list(unit_ids), disabled=False, - layout=widgets.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), ) controller = {"unit_ids": unit_selector} - widget = widgets.VBox([unit_label, unit_selector]) + widget = W.VBox([unit_label, unit_selector]) return widget, controller def make_channel_controller(recording, width_cm, height_cm): - channel_label = widgets.Label("channel indices:", layout=widgets.Layout(justify_content="center")) - channel_selector = widgets.IntRangeSlider( + channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) + channel_selector = W.IntRangeSlider( value=[0, recording.get_num_channels()], min=0, max=recording.get_num_channels(), @@ -68,37 +70,238 @@ def make_channel_controller(recording, width_cm, height_cm): orientation="vertical", readout=True, readout_format="d", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), ) controller = {"channel_inds": channel_selector} - widget = widgets.VBox([channel_label, channel_selector]) + widget = W.VBox([channel_label, channel_selector]) return widget, controller def make_scale_controller(width_cm, height_cm): - scale_label = widgets.Label("Scale", layout=widgets.Layout(justify_content="center")) + scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) - plus_selector = widgets.Button( + plus_selector = W.Button( description="", disabled=False, button_style="", # 'success', 'info', 'warning', 'danger' or '' tooltip="Increase scale", icon="arrow-up", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), ) - minus_selector = widgets.Button( + minus_selector = W.Button( description="", disabled=False, button_style="", # 'success', 'info', 'warning', 'danger' or '' tooltip="Decrease scale", icon="arrow-down", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), ) controller = {"plus": plus_selector, "minus": minus_selector} - widget = widgets.VBox([scale_label, plus_selector, minus_selector]) + widget = W.VBox([scale_label, plus_selector, minus_selector]) return widget, controller + + + +class TimeSlider(W.HBox): + + position = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) + + def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): + + + self.num_segments = len(durations) + self.frame_limits = [int(sampling_frequency * d) for d in durations] + self.sampling_frequency = sampling_frequency + start_frame = int(time_range[0] * sampling_frequency) + end_frame = int(time_range[1] * sampling_frequency) + + self.frame_range = (start_frame, end_frame) + + self.segment_index = 0 + self.position = (start_frame, end_frame, self.segment_index) + + + layout = W.Layout(align_items="center", width="1.5cm", height="100%") + but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) + but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) + + but_left.on_click(self.move_left) + but_right.on_click(self.move_right) + + self.move_size = W.Dropdown(options=['10 ms', '100 ms', '1 s', '10 s', '1 m', '30 m', '1 h',], # '6 h', '24 h' + value='1 s', + description='', + layout = W.Layout(width="2cm") + ) + + # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) + self.time_label = W.Text(value=f'{time_range[0]}',description='', + disabled=False, layout=W.Layout(width='5.5cm')) + self.time_label.observe(self.time_label_changed, names='value', type="change") + + + self.slider = W.IntSlider( + orientation='horizontal', + # description='time:', + value=start_frame, + min=0, + max=self.frame_limits[self.segment_index], + readout=False, + continuous_update=False, + layout=W.Layout(width=f'70%') + ) + + self.slider.observe(self.slider_moved, names='value', type="change") + + delta_s = np.diff(self.frame_range) / sampling_frequency + + self.window_sizer = W.BoundedFloatText(value=delta_s, step=1, + min=0.01, max=30., + description='win (s)', + layout=W.Layout(width='auto') + # layout=W.Layout(width=f'10%') + ) + self.window_sizer.observe(self.win_size_changed, names='value', type="change") + + self.segment_selector = W.Dropdown(description="segment", options=list(range(self.num_segments))) + self.segment_selector.observe(self.segment_changed, names='value', type="change") + + super(W.HBox, self).__init__(children=[self.segment_selector, but_left, self.move_size, but_right, + self.slider, self.time_label, self.window_sizer], + layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs) + + self.observe(self.position_changed, names=['position'], type="change") + + def position_changed(self, change=None): + + self.unobserve(self.position_changed, names=['position'], type="change") + + start, stop, seg_index = self.position + if seg_index < 0 or seg_index >= self.num_segments: + self.position = change['old'] + return + if start < 0 or stop < 0: + self.position = change['old'] + return + if start >= self.frame_limits[seg_index] or start > self.frame_limits[seg_index]: + self.position = change['old'] + return + + self.segment_selector.value = seg_index + self.update_time(new_frame=start, update_slider=True, update_label=True) + delta_s = (stop - start) / self.sampling_frequency + self.window_sizer.value = delta_s + + self.observe(self.position_changed, names=['position'], type="change") + + def update_time(self, new_frame=None, new_time=None, update_slider=False, update_label=False): + if new_frame is None and new_time is None: + start_frame = self.slider.value + elif new_frame is None: + start_frame = int(new_time * self.sampling_frequency) + else: + start_frame = new_frame + delta_s = self.window_sizer.value + end_frame = start_frame + int(delta_s * self.sampling_frequency) + + # clip + start_frame = max(0, start_frame) + end_frame = min(self.frame_limits[self.segment_index], end_frame) + + + start_time = start_frame / self.sampling_frequency + + if update_label: + self.time_label.unobserve(self.time_label_changed, names='value', type="change") + self.time_label.value = f'{start_time}' + self.time_label.observe(self.time_label_changed, names='value', type="change") + + if update_slider: + self.slider.unobserve(self.slider_moved, names='value', type="change") + self.slider.value = start_frame + self.slider.observe(self.slider_moved, names='value', type="change") + + self.frame_range = (start_frame, end_frame) + + def time_label_changed(self, change=None): + try: + new_time = float(self.time_label.value) + except: + new_time = None + if new_time is not None: + self.update_time(new_time=new_time, update_slider=True) + + + def win_size_changed(self, change=None): + self.update_time() + + def slider_moved(self, change=None): + new_frame = self.slider.value + self.update_time(new_frame=new_frame, update_label=True) + + def move(self, sign): + value, units = self.move_size.value.split(' ') + value = int(value) + delta_s = (sign * np.timedelta64(value, units)) / np.timedelta64(1, 's') + delta_sample = int(delta_s * self.sampling_frequency) + + new_frame = self.frame_range[0] + delta_sample + self.slider.value = new_frame + + def move_left(self, change=None): + self.move(-1) + + def move_right(self, change=None): + self.move(+1) + + def segment_changed(self, change=None): + self.segment_index = self.segment_selector.value + + self.slider.unobserve(self.slider_moved, names='value', type="change") + # self.slider.value = 0 + self.slider.max = self.frame_limits[self.segment_index] + self.slider.observe(self.slider_moved, names='value', type="change") + + self.update_time(new_frame=0, update_slider=True, update_label=True) + + + +class ScaleWidget(W.VBox): + def __init__(self, **kwargs): + scale_label = W.Label("Scale", + layout=W.Layout(layout=W.Layout(width='95%'), + justify_content="center")) + + self.plus_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Increase scale", + icon="arrow-up", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width='95%'), + ) + + self.minus_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Decrease scale", + icon="arrow-down", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width='95%'), + ) + + # controller = {"plus": plus_selector, "minus": minus_selector} + # widget = W.VBox([scale_label, plus_selector, minus_selector]) + + + super(W.VBox, self).__init__(children=[scale_label, self.plus_selector, self.minus_selector], + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs) From 389737efe1330f1f75afb73caedb41bb6bf84b4d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Sep 2023 20:58:38 +0200 Subject: [PATCH 2/8] wip refactor plot traces ipywidget --- src/spikeinterface/widgets/traces.py | 126 ++++++++++++++---- .../widgets/utils_ipywidgets.py | 62 ++++++--- 2 files changed, 145 insertions(+), 43 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index c6e36387f8..efd32ffb24 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -279,9 +279,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import ipywidgets.widgets as W from .utils_ipywidgets import ( check_ipywidget_backend, - make_timeseries_controller, + # make_timeseries_controller, make_channel_controller, - make_scale_controller, + # make_scale_controller, TimeSlider, ScaleWidget, @@ -315,21 +315,22 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): - ts_widget, ts_controller = make_timeseries_controller( - t_start, - t_stop, - data_plot["layer_keys"], - rec0.get_num_segments(), - data_plot["time_range"], - data_plot["mode"], - False, - width_cm, - ) + # ts_widget, ts_controller = make_timeseries_controller( + # t_start, + # t_stop, + # data_plot["layer_keys"], + # rec0.get_num_segments(), + # data_plot["time_range"], + # data_plot["mode"], + # False, + # width_cm, + # ) # some widgets self.time_slider = TimeSlider( durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], sampling_frequency=rec0.sampling_frequency, + # layout=W.Layout(height="2cm"), ) self.layer_selector = W.Dropdown(description="layer", options=data_plot["layer_keys"], layout=W.Layout(width="5cm"),) @@ -338,22 +339,22 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.scaler = ScaleWidget() left_sidebar = W.VBox( children=[self.layer_selector, self.mode_selector, self.scaler], - layout=W.Layout(width="5cm"), + layout=W.Layout(width="3.5cm"), ) ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) - scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) + # scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) - self.controller = ts_controller - self.controller.update(ch_controller) - self.controller.update(scale_controller) + # self.controller = ts_controller + # self.controller.update(ch_controller) + # self.controller.update(scale_controller) self.recordings = data_plot["recordings"] self.return_scaled = data_plot["return_scaled"] self.list_traces = None - self.actual_segment_index = self.controller["segment_index"].value + # self.actual_segment_index = self.controller["segment_index"].value self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] self.t_stops = [ @@ -361,11 +362,11 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): for seg_index in range(self.rec0.get_num_segments()) ] - for w in self.controller.values(): - if isinstance(w, widgets.Button): - w.on_click(self._update_ipywidget) - else: - w.observe(self._update_ipywidget) + # for w in self.controller.values(): + # if isinstance(w, widgets.Button): + # w.on_click(self._update_ipywidget) + # else: + # w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=self.figure.canvas, @@ -379,12 +380,89 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - self._update_ipywidget(None) + # self._update_ipywidget(None) + + self._retrieve_traces() + self._update_plot() + + # only layer selector and time change generate a new traces retrieve + self.time_slider.observe(self._retrieve_traces, names='value', type="change") + self.layer_selector.observe(self._retrieve_traces, names='value', type="change") + # other widgets only refresh + self.scaler.observe(self._update_plot, names='value', type="change") + self.mode_selector.observe(self._update_plot, names='value', type="change") + if backend_kwargs["display"]: # self.check_backend() display(self.widget) + + + def _retrieve_traces(self, change=None): + # done when: + # * time or window is changes + # * layer is changed + + # TODO connect with channel selector + channel_ids = self.rec0.channel_ids + + # all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids + # if self.data_plot["order"] is not None: + # all_channel_ids = all_channel_ids[self.data_plot["order"]] + # channel_ids = all_channel_ids[channel_indices] + if self.data_plot["order_channel_by_depth"]: + order, _ = order_channels_by_depth(self.rec0, channel_ids) + else: + order = None + + start_frame, end_frame, segment_index = self.time_slider.value + time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency + + times, list_traces, frame_range, channel_ids = _get_trace_list( + self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled + ) + self.list_traces = list_traces + + self._update_plot() + + def _update_plot(self, change=None): + # done when: + # * time or window is changed (after _retrive_traces) + # * layer is changed (after _retrive_traces) + #  * scale is change + # * mode is change + + data_plot = self.next_data_plot + + # matplotlib next_data_plot dict update at each call + data_plot["mode"] = self.mode_selector.value + # data_plot["frame_range"] = frame_range + # data_plot["time_range"] = time_range + data_plot["with_colorbar"] = False + # data_plot["recordings"] = recordings + # data_plot["layer_keys"] = layer_keys + # data_plot["list_traces"] = list_traces_plot + # data_plot["times"] = times + # data_plot["clims"] = clims + # data_plot["channel_ids"] = channel_ids + + list_traces = [traces * self.scaler.value for traces in self.list_traces] + data_plot["list_traces"] = list_traces + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + self.ax.clear() + self.plot_matplotlib(data_plot, **backend_kwargs) + + fig = self.ax.figure + fig.canvas.draw() + fig.canvas.flush_events() + + + + def _update_ipywidget(self, change): import ipywidgets.widgets as widgets diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 674a2d2cc7..ad0ead7bc0 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -109,7 +109,7 @@ def make_scale_controller(width_cm, height_cm): class TimeSlider(W.HBox): - position = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) + value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): @@ -123,10 +123,10 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): self.frame_range = (start_frame, end_frame) self.segment_index = 0 - self.position = (start_frame, end_frame, self.segment_index) + self.value = (start_frame, end_frame, self.segment_index) - layout = W.Layout(align_items="center", width="1.5cm", height="100%") + layout = W.Layout(align_items="center", width="2cm", hight="1.5cm") but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) @@ -176,21 +176,21 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): layout=W.Layout(align_items="center", width="100%", height="100%"), **kwargs) - self.observe(self.position_changed, names=['position'], type="change") + self.observe(self.value_changed, names=['value'], type="change") - def position_changed(self, change=None): + def value_changed(self, change=None): - self.unobserve(self.position_changed, names=['position'], type="change") + self.unobserve(self.value_changed, names=['value'], type="change") - start, stop, seg_index = self.position + start, stop, seg_index = self.value if seg_index < 0 or seg_index >= self.num_segments: - self.position = change['old'] + self.value = change['old'] return if start < 0 or stop < 0: - self.position = change['old'] + self.value = change['old'] return if start >= self.frame_limits[seg_index] or start > self.frame_limits[seg_index]: - self.position = change['old'] + self.value = change['old'] return self.segment_selector.value = seg_index @@ -198,7 +198,7 @@ def position_changed(self, change=None): delta_s = (stop - start) / self.sampling_frequency self.window_sizer.value = delta_s - self.observe(self.position_changed, names=['position'], type="change") + self.observe(self.value_changed, names=['value'], type="change") def update_time(self, new_frame=None, new_time=None, update_slider=False, update_label=False): if new_frame is None and new_time is None: @@ -228,6 +228,7 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update self.slider.observe(self.slider_moved, names='value', type="change") self.frame_range = (start_frame, end_frame) + self.value = (start_frame, end_frame, self.segment_index) def time_label_changed(self, change=None): try: @@ -273,8 +274,14 @@ def segment_changed(self, change=None): class ScaleWidget(W.VBox): - def __init__(self, **kwargs): - scale_label = W.Label("Scale", + value = traitlets.Float() + + def __init__(self, value=1., factor=1.2, **kwargs): + + assert factor > 1. + self.factor = factor + + self.scale_label = W.Label("Scale", layout=W.Layout(layout=W.Layout(width='95%'), justify_content="center")) @@ -285,7 +292,7 @@ def __init__(self, **kwargs): tooltip="Increase scale", icon="arrow-up", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='95%'), + layout=W.Layout(width='60%', align_self='center'), ) self.minus_selector = W.Button( @@ -295,13 +302,30 @@ def __init__(self, **kwargs): tooltip="Decrease scale", icon="arrow-down", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='95%'), + layout=W.Layout(width='60%', align_self='center'), ) - # controller = {"plus": plus_selector, "minus": minus_selector} - # widget = W.VBox([scale_label, plus_selector, minus_selector]) + self.plus_selector.on_click(self.plus_clicked) + self.minus_selector.on_click(self.minus_clicked) - - super(W.VBox, self).__init__(children=[scale_label, self.plus_selector, self.minus_selector], + self.value = 1. + super(W.VBox, self).__init__(children=[self.plus_selector, self.scale_label, self.minus_selector], # layout=W.Layout(align_items="center", width="100%", height="100%"), **kwargs) + + self.update_label() + self.observe(self.value_changed, names=['value'], type="change") + + def update_label(self): + self.scale_label.value = f"Scale: {self.value:0.2f}" + + + def plus_clicked(self, change=None): + self.value = self.value * self.factor + + def minus_clicked(self, change=None): + self.value = self.value / self.factor + + + def value_changed(self, change=None): + self.update_label() From e5995f2aa6445fd878e1c0881f11299f8ae22a2d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Sep 2023 22:59:53 +0200 Subject: [PATCH 3/8] ipywidget backend refactor wip --- src/spikeinterface/widgets/traces.py | 298 +++++------------- .../widgets/utils_ipywidgets.py | 175 ++++++---- 2 files changed, 190 insertions(+), 283 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index efd32ffb24..d107c5cb23 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -280,23 +280,23 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): from .utils_ipywidgets import ( check_ipywidget_backend, # make_timeseries_controller, - make_channel_controller, + # make_channel_controller, # make_scale_controller, - TimeSlider, + ChannelSelector, ScaleWidget, - ) check_ipywidget_backend() self.next_data_plot = data_plot.copy() - self.next_data_plot["add_legend"] = False + - recordings = data_plot["recordings"] + self.recordings = data_plot["recordings"] # first layer - rec0 = recordings[data_plot["layer_keys"][0]] + # rec0 = recordings[data_plot["layer_keys"][0]] + rec0 = self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] cm = 1 / 2.54 @@ -310,107 +310,92 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.figure, self.ax = plt.subplots(figsize=(0.9 * ratios[1] * width_cm * cm, height_cm * cm)) plt.show() - t_start = 0.0 - t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() - - - - # ts_widget, ts_controller = make_timeseries_controller( - # t_start, - # t_stop, - # data_plot["layer_keys"], - # rec0.get_num_segments(), - # data_plot["time_range"], - # data_plot["mode"], - # False, - # width_cm, - # ) - # some widgets self.time_slider = TimeSlider( durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], sampling_frequency=rec0.sampling_frequency, # layout=W.Layout(height="2cm"), ) - self.layer_selector = W.Dropdown(description="layer", options=data_plot["layer_keys"], - layout=W.Layout(width="5cm"),) - self.mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=data_plot["mode"], - layout=W.Layout(width="5cm"),) + + start_frame = int(data_plot["time_range"][0] * rec0.sampling_frequency) + end_frame = int(data_plot["time_range"][1] * rec0.sampling_frequency) + + self.time_slider.value = start_frame, end_frame, data_plot["segment_index"] + + _layer_keys = data_plot["layer_keys"] + if len(_layer_keys) > 1: + _layer_keys = ['ALL'] + _layer_keys + self.layer_selector = W.Dropdown(options=_layer_keys, + layout=W.Layout(width="95%"), + ) + self.mode_selector = W.Dropdown(options=["line", "map"], value=data_plot["mode"], + # layout=W.Layout(width="5cm"), + layout=W.Layout(width="95%"), + ) self.scaler = ScaleWidget() + self.channel_selector = ChannelSelector(self.rec0.channel_ids) + left_sidebar = W.VBox( - children=[self.layer_selector, self.mode_selector, self.scaler], + children=[ + W.Label(value="layer"), + self.layer_selector, + W.Label(value="mode"), + self.mode_selector, + self.scaler, + # self.channel_selector, + ], layout=W.Layout(width="3.5cm"), + align_items='center', ) - - ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) - - # scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) - - # self.controller = ts_controller - # self.controller.update(ch_controller) - # self.controller.update(scale_controller) - - self.recordings = data_plot["recordings"] self.return_scaled = data_plot["return_scaled"] - self.list_traces = None - # self.actual_segment_index = self.controller["segment_index"].value - - self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] - self.t_stops = [ - self.rec0.get_num_samples(segment_index=seg_index) / self.rec0.get_sampling_frequency() - for seg_index in range(self.rec0.get_num_segments()) - ] - - # for w in self.controller.values(): - # if isinstance(w, widgets.Button): - # w.on_click(self._update_ipywidget) - # else: - # w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=self.figure.canvas, - # footer=ts_widget, footer=self.time_slider, - # left_sidebar=scale_widget, left_sidebar = left_sidebar, - right_sidebar=ch_widget, + right_sidebar=self.channel_selector, pane_heights=[0, 6, 1], pane_widths=ratios, ) # a first update - # self._update_ipywidget(None) - self._retrieve_traces() self._update_plot() - # only layer selector and time change generate a new traces retrieve + # callbacks: + # some widgets generate a full retrieve + refresh self.time_slider.observe(self._retrieve_traces, names='value', type="change") self.layer_selector.observe(self._retrieve_traces, names='value', type="change") + self.channel_selector.observe(self._retrieve_traces, names='value', type="change") # other widgets only refresh self.scaler.observe(self._update_plot, names='value', type="change") - self.mode_selector.observe(self._update_plot, names='value', type="change") + # map is a special case because needs to check layer also + self.mode_selector.observe(self._mode_changed, names='value', type="change") - if backend_kwargs["display"]: # self.check_backend() display(self.widget) - + def _get_layers(self): + layer = self.layer_selector.value + if layer == 'ALL': + layer_keys = self.data_plot["layer_keys"] + else: + layer_keys = [layer] + if self.mode_selector.value == "map": + layer_keys = layer_keys[:1] + return layer_keys + + def _mode_changed(self, change=None): + if self.mode_selector.value == "map" and self.layer_selector.value == "ALL": + self.layer_selector.value = self.data_plot["layer_keys"][0] + else: + self._update_plot() def _retrieve_traces(self, change=None): - # done when: - # * time or window is changes - # * layer is changed + channel_ids = np.array(self.channel_selector.value) - # TODO connect with channel selector - channel_ids = self.rec0.channel_ids - - # all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids - # if self.data_plot["order"] is not None: - # all_channel_ids = all_channel_ids[self.data_plot["order"]] - # channel_ids = all_channel_ids[channel_indices] if self.data_plot["order_channel_by_depth"]: order, _ = order_channels_by_depth(self.rec0, channel_ids) else: @@ -419,176 +404,61 @@ def _retrieve_traces(self, change=None): start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency + self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times, list_traces, frame_range, channel_ids = _get_trace_list( - self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled + self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled ) - self.list_traces = list_traces + + self._channel_ids = channel_ids + self._list_traces = list_traces + self._times = times + self._time_range = time_range + self._frame_range = (start_frame, end_frame) + self._segment_index = segment_index self._update_plot() def _update_plot(self, change=None): - # done when: - # * time or window is changed (after _retrive_traces) - # * layer is changed (after _retrive_traces) - #  * scale is change - # * mode is change - data_plot = self.next_data_plot # matplotlib next_data_plot dict update at each call - data_plot["mode"] = self.mode_selector.value - # data_plot["frame_range"] = frame_range - # data_plot["time_range"] = time_range - data_plot["with_colorbar"] = False - # data_plot["recordings"] = recordings - # data_plot["layer_keys"] = layer_keys - # data_plot["list_traces"] = list_traces_plot - # data_plot["times"] = times - # data_plot["clims"] = clims - # data_plot["channel_ids"] = channel_ids - - list_traces = [traces * self.scaler.value for traces in self.list_traces] - data_plot["list_traces"] = list_traces - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.ax.clear() - self.plot_matplotlib(data_plot, **backend_kwargs) - - fig = self.ax.figure - fig.canvas.draw() - fig.canvas.flush_events() - - - - - def _update_ipywidget(self, change): - import ipywidgets.widgets as widgets - - # if changing the layer_key, no need to retrieve and process traces - retrieve_traces = True - scale_up = False - scale_down = False - if change is not None: - for cname, c in self.controller.items(): - if isinstance(change, dict): - if change["owner"] is c and cname == "layer_key": - retrieve_traces = False - elif isinstance(change, widgets.Button): - if change is c and cname == "plus": - scale_up = True - if change is c and cname == "minus": - scale_down = True - - t_start = self.controller["t_start"].value - window = self.controller["window"].value - layer_key = self.controller["layer_key"].value - segment_index = self.controller["segment_index"].value - mode = self.controller["mode"].value - chan_start, chan_stop = self.controller["channel_inds"].value - - if mode == "line": - self.controller["all_layers"].layout.visibility = "visible" - all_layers = self.controller["all_layers"].value - elif mode == "map": - self.controller["all_layers"].layout.visibility = "hidden" - all_layers = False - - if all_layers: - self.controller["layer_key"].layout.visibility = "hidden" - else: - self.controller["layer_key"].layout.visibility = "visible" - - if chan_start == chan_stop: - chan_stop += 1 - channel_indices = np.arange(chan_start, chan_stop) - - t_stop = self.t_stops[segment_index] - if self.actual_segment_index != segment_index: - # change time_slider limits - self.controller["t_start"].max = t_stop - self.actual_segment_index = segment_index - - # protect limits - if t_start >= t_stop - window: - t_start = t_stop - window - - time_range = np.array([t_start, t_start + window]) - data_plot = self.next_data_plot + mode = self.mode_selector.value + layer_keys = self._get_layers() - if retrieve_traces: - all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids - if self.data_plot["order"] is not None: - all_channel_ids = all_channel_ids[self.data_plot["order"]] - channel_ids = all_channel_ids[channel_indices] - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None - times, list_traces, frame_range, channel_ids = _get_trace_list( - self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled - ) - self.list_traces = list_traces - else: - times = data_plot["times"] - list_traces = data_plot["list_traces"] - frame_range = data_plot["frame_range"] - channel_ids = data_plot["channel_ids"] - - if all_layers: - layer_keys = self.data_plot["layer_keys"] - recordings = self.recordings - list_traces_plot = self.list_traces - else: - layer_keys = [layer_key] - recordings = {layer_key: self.recordings[layer_key]} - list_traces_plot = [self.list_traces[list(self.recordings.keys()).index(layer_key)]] - - if scale_up: - if mode == "line": - data_plot["vspacing"] *= 0.8 - elif mode == "map": - data_plot["clims"] = { - layer: (1.2 * val[0], 1.2 * val[1]) for layer, val in self.data_plot["clims"].items() - } - if scale_down: - if mode == "line": - data_plot["vspacing"] *= 1.2 - elif mode == "map": - data_plot["clims"] = { - layer: (0.8 * val[0], 0.8 * val[1]) for layer, val in self.data_plot["clims"].items() - } - - self.next_data_plot["vspacing"] = data_plot["vspacing"] - self.next_data_plot["clims"] = data_plot["clims"] + data_plot["mode"] = mode + data_plot["frame_range"] = self._frame_range + data_plot["time_range"] = self._time_range + data_plot["with_colorbar"] = False + data_plot["recordings"] = self._selected_recordings + data_plot["add_legend"] = False if mode == "line": clims = None elif mode == "map": - clims = {layer_key: self.data_plot["clims"][layer_key]} + clims = {k: self.data_plot["clims"][k] for k in layer_keys} - # matplotlib next_data_plot dict update at each call - data_plot["mode"] = mode - data_plot["frame_range"] = frame_range - data_plot["time_range"] = time_range - data_plot["with_colorbar"] = False - data_plot["recordings"] = recordings - data_plot["layer_keys"] = layer_keys - data_plot["list_traces"] = list_traces_plot - data_plot["times"] = times data_plot["clims"] = clims - data_plot["channel_ids"] = channel_ids + data_plot["channel_ids"] = self._channel_ids + + data_plot["layer_keys"] = layer_keys + data_plot["colors"] = {k:self.data_plot["colors"][k] for k in layer_keys} + + list_traces = [traces * self.scaler.value for traces in self._list_traces] + data_plot["list_traces"] = list_traces + data_plot["times"] = self._times backend_kwargs = {} backend_kwargs["ax"] = self.ax + self.ax.clear() self.plot_matplotlib(data_plot, **backend_kwargs) + self.ax.set_title("") fig = self.ax.figure fig.canvas.draw() fig.canvas.flush_events() + def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import handle_display_and_url diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index ad0ead7bc0..ab2b51a7bb 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -11,35 +11,35 @@ def check_ipywidget_backend(): assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" -def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): - time_slider = W.FloatSlider( - orientation="horizontal", - description="time:", - value=time_range[0], - min=t_start, - max=t_stop, - continuous_update=False, - layout=W.Layout(width=f"{width_cm}cm"), - ) - layer_selector = W.Dropdown(description="layer", options=layer_keys) - segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) - window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") - mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) - all_layers = W.Checkbox(description="plot all layers", value=all_layers) - - controller = { - "layer_key": layer_selector, - "segment_index": segment_selector, - "window": window_sizer, - "t_start": time_slider, - "mode": mode_selector, - "all_layers": all_layers, - } - widget = W.VBox( - [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] - ) - - return widget, controller +# def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): +# time_slider = W.FloatSlider( +# orientation="horizontal", +# description="time:", +# value=time_range[0], +# min=t_start, +# max=t_stop, +# continuous_update=False, +# layout=W.Layout(width=f"{width_cm}cm"), +# ) +# layer_selector = W.Dropdown(description="layer", options=layer_keys) +# segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) +# window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") +# mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) +# all_layers = W.Checkbox(description="plot all layers", value=all_layers) + +# controller = { +# "layer_key": layer_selector, +# "segment_index": segment_selector, +# "window": window_sizer, +# "t_start": time_slider, +# "mode": mode_selector, +# "all_layers": all_layers, +# } +# widget = W.VBox( +# [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] +# ) + +# return widget, controller def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): @@ -58,52 +58,52 @@ def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): return widget, controller -def make_channel_controller(recording, width_cm, height_cm): - channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) - channel_selector = W.IntRangeSlider( - value=[0, recording.get_num_channels()], - min=0, - max=recording.get_num_channels(), - step=1, - disabled=False, - continuous_update=False, - orientation="vertical", - readout=True, - readout_format="d", - layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), - ) +# def make_channel_controller(recording, width_cm, height_cm): +# channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) +# channel_selector = W.IntRangeSlider( +# value=[0, recording.get_num_channels()], +# min=0, +# max=recording.get_num_channels(), +# step=1, +# disabled=False, +# continuous_update=False, +# orientation="vertical", +# readout=True, +# readout_format="d", +# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), +# ) - controller = {"channel_inds": channel_selector} - widget = W.VBox([channel_label, channel_selector]) +# controller = {"channel_inds": channel_selector} +# widget = W.VBox([channel_label, channel_selector]) - return widget, controller +# return widget, controller -def make_scale_controller(width_cm, height_cm): - scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) +# def make_scale_controller(width_cm, height_cm): +# scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) - plus_selector = W.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Increase scale", - icon="arrow-up", - layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) +# plus_selector = W.Button( +# description="", +# disabled=False, +# button_style="", # 'success', 'info', 'warning', 'danger' or '' +# tooltip="Increase scale", +# icon="arrow-up", +# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), +# ) - minus_selector = W.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Decrease scale", - icon="arrow-down", - layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) +# minus_selector = W.Button( +# description="", +# disabled=False, +# button_style="", # 'success', 'info', 'warning', 'danger' or '' +# tooltip="Decrease scale", +# icon="arrow-down", +# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), +# ) - controller = {"plus": plus_selector, "minus": minus_selector} - widget = W.VBox([scale_label, plus_selector, minus_selector]) +# controller = {"plus": plus_selector, "minus": minus_selector} +# widget = W.VBox([scale_label, plus_selector, minus_selector]) - return widget, controller +# return widget, controller @@ -126,7 +126,7 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): self.value = (start_frame, end_frame, self.segment_index) - layout = W.Layout(align_items="center", width="2cm", hight="1.5cm") + layout = W.Layout(align_items="center", width="2.5cm", height="1.cm") but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) @@ -141,7 +141,7 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) self.time_label = W.Text(value=f'{time_range[0]}',description='', - disabled=False, layout=W.Layout(width='5.5cm')) + disabled=False, layout=W.Layout(width='2.5cm')) self.time_label.observe(self.time_label_changed, names='value', type="change") @@ -271,6 +271,43 @@ def segment_changed(self, change=None): self.update_time(new_frame=0, update_slider=True, update_label=True) +class ChannelSelector(W.VBox): + value = traitlets.List() + + def __init__(self, channel_ids, **kwargs): + self.channel_ids = list(channel_ids) + self.value = self.channel_ids + + channel_label = W.Label("Channels", layout=W.Layout(justify_content="center")) + n = len(channel_ids) + self.slider = W.IntRangeSlider( + value=[0, n], + min=0, + max=n, + step=1, + disabled=False, + continuous_update=False, + orientation="vertical", + readout=True, + readout_format="d", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(height="100%"), + ) + + + + super(W.VBox, self).__init__(children=[channel_label, self.slider], + layout=W.Layout(align_items="center"), + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs) + self.slider.observe(self.on_slider_changed, names=['value'], type="change") + # self.update_label() + # self.observe(self.value_changed, names=['value'], type="change") + + def on_slider_changed(self, change=None): + i0, i1 = self.slider.value + self.value = self.channel_ids[i0:i1] + class ScaleWidget(W.VBox): From 7b92c2153d4fad412823100fd77079e3cf286138 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 08:06:37 +0200 Subject: [PATCH 4/8] improve channel selector --- .../widgets/utils_ipywidgets.py | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index ab2b51a7bb..705dd09f23 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -294,20 +294,52 @@ def __init__(self, channel_ids, **kwargs): layout=W.Layout(height="100%"), ) + # first channel are bottom: need reverse + self.selector = W.SelectMultiple( + options=self.channel_ids[::-1], + value=self.channel_ids[::-1], + disabled=False, + # layout=W.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(height="100%", width="2cm"), + ) + hbox = W.HBox(children=[self.slider, self.selector]) - - super(W.VBox, self).__init__(children=[channel_label, self.slider], + super(W.VBox, self).__init__(children=[channel_label, hbox], layout=W.Layout(align_items="center"), # layout=W.Layout(align_items="center", width="100%", height="100%"), **kwargs) self.slider.observe(self.on_slider_changed, names=['value'], type="change") - # self.update_label() + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + + # TODO external value change # self.observe(self.value_changed, names=['value'], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value + + self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") + self.selector.value = self.channel_ids[i0:i1][::-1] + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + self.value = self.channel_ids[i0:i1] + def on_selector_changed(self, change=None): + channel_ids = self.selector.value + channel_ids = channel_ids[::-1] + + if len(channel_ids) > 0: + self.slider.unobserve(self.on_slider_changed, names=['value'], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=['value'], type="change") + + self.value = channel_ids + + + + + class ScaleWidget(W.VBox): From 2e305586d5b39bb8bfa89280057579a97726e93a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 11:09:05 +0200 Subject: [PATCH 5/8] ipywidgets backend start UnitCOntroller --- src/spikeinterface/widgets/amplitudes.py | 69 ++++++++++--------- .../widgets/utils_ipywidgets.py | 39 +++++++++-- 2 files changed, 71 insertions(+), 37 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 7ef6e0ff61..b60de98cb0 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -171,9 +171,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - import ipywidgets.widgets as widgets + # import ipywidgets.widgets as widgets + import ipywidgets.widgets as W from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller, UnitSelector check_ipywidget_backend() @@ -188,60 +189,62 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ratios = [0.15, 0.85] with plt.ioff(): - output = widgets.Output() + output = W.Output() with output: self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.unit_selector = UnitSelector(we.unit_ids) + self.unit_selector.value = list(we.unit_ids)[:1] - plot_histograms = widgets.Checkbox( + self.checkbox_histograms = W.Checkbox( value=data_plot["plot_histograms"], - description="plot histograms", - disabled=False, + description="hist", + # disabled=False, ) - footer = plot_histograms - - self.controller = {"plot_histograms": plot_histograms} - self.controller.update(unit_controller) - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + left_sidebar = W.VBox( + children=[ + self.unit_selector, + self.checkbox_histograms, + ], + layout = W.Layout(align_items="center", width="4cm", height="100%"), + ) - self.widget = widgets.AppLayout( + self.widget = W.AppLayout( center=self.figure.canvas, - left_sidebar=unit_widget, + left_sidebar=left_sidebar, pane_widths=ratios + [0], - footer=footer, ) # a first update - self._update_ipywidget(None) + self._full_update_plot() + + self.unit_selector.observe(self._update_plot, names='value', type="change") + self.checkbox_histograms.observe(self._full_update_plot, names='value', type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _full_update_plot(self, change=None): self.figure.clear() + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + data_plot["plot_histograms"] = self.checkbox_histograms.value + + backend_kwargs = dict(figure=self.figure, axes=None, ax=None) + self.plot_matplotlib(data_plot, **backend_kwargs) + self._update_plot() - unit_ids = self.controller["unit_ids"].value - plot_histograms = self.controller["plot_histograms"].value + def _update_plot(self, change=None): + for ax in self.axes.flatten(): + ax.clear() - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_histograms"] = plot_histograms - - backend_kwargs = {} - # backend_kwargs["figure"] = self.fig - backend_kwargs["figure"] = self.figure - backend_kwargs["axes"] = None - backend_kwargs["ax"] = None + data_plot["unit_ids"] = self.unit_selector.value + data_plot["plot_histograms"] = self.checkbox_histograms.value + backend_kwargs = dict(figure=None, axes=self.axes, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) self.figure.canvas.draw() diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 705dd09f23..d2c41f234a 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -338,10 +338,6 @@ def on_selector_changed(self, change=None): - - - - class ScaleWidget(W.VBox): value = traitlets.Float() @@ -398,3 +394,38 @@ def minus_clicked(self, change=None): def value_changed(self, change=None): self.update_label() + + +class UnitSelector(W.VBox): + value = traitlets.List() + + def __init__(self, unit_ids, **kwargs): + self.unit_ids = list(unit_ids) + self.value = self.unit_ids + + label = W.Label("Units", layout=W.Layout(justify_content="center")) + + self.selector = W.SelectMultiple( + options=self.unit_ids, + value=self.unit_ids, + disabled=False, + layout=W.Layout(height="100%", width="2cm"), + ) + + super(W.VBox, self).__init__(children=[label, self.selector], + layout=W.Layout(align_items="center"), + **kwargs) + + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + + self.observe(self.value_changed, names=['value'], type="change") + + def on_selector_changed(self, change=None): + unit_ids = self.selector.value + self.value = unit_ids + + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") + self.selector.value = change['new'] + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + From 4e79b5811d41e6343391a3a6b26fab97f657368b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 13:32:51 +0200 Subject: [PATCH 6/8] propagate UnitSelector to others ipywidgets --- src/spikeinterface/widgets/amplitudes.py | 12 ++- src/spikeinterface/widgets/base.py | 3 +- src/spikeinterface/widgets/metrics.py | 21 ++-- src/spikeinterface/widgets/spike_locations.py | 34 +++---- .../widgets/spikes_on_traces.py | 87 ++++++++++------- src/spikeinterface/widgets/unit_locations.py | 29 +++--- src/spikeinterface/widgets/unit_waveforms.py | 50 +++++----- .../widgets/utils_ipywidgets.py | 96 ------------------- 8 files changed, 121 insertions(+), 211 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index b60de98cb0..5aa090b1b4 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -147,13 +147,16 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: bins = dp.bins ax_hist = self.axes.flatten()[1] - ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + # this is super slow, using plot and np.histogram is really much faster (and nicer!) + # ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + count, bins = np.histogram(amps, bins=bins) + ax_hist.plot(count, bins[:-1], color=dp.unit_colors[unit_id], alpha=0.8) if dp.plot_histograms: ax_hist = self.axes.flatten()[1] ax_hist.set_ylim(scatter_ax.get_ylim()) ax_hist.axis("off") - self.figure.tight_layout() + # self.figure.tight_layout() if dp.plot_legend: if hasattr(self, "legend") and self.legend is not None: @@ -174,7 +177,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # import ipywidgets.widgets as widgets import ipywidgets.widgets as W from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller, UnitSelector + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -200,7 +203,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.checkbox_histograms = W.Checkbox( value=data_plot["plot_histograms"], description="hist", - # disabled=False, ) left_sidebar = W.VBox( @@ -231,6 +233,7 @@ def _full_update_plot(self, change=None): data_plot = self.next_data_plot data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_histograms"] = self.checkbox_histograms.value + data_plot["plot_legend"] = False backend_kwargs = dict(figure=self.figure, axes=None, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) @@ -243,6 +246,7 @@ def _update_plot(self, change=None): data_plot = self.next_data_plot data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_histograms"] = self.checkbox_histograms.value + data_plot["plot_legend"] = False backend_kwargs = dict(figure=None, axes=self.axes, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 4ed83fcca9..1ff691320a 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -38,6 +38,7 @@ def set_default_plotter_backend(backend): "width_cm": "Width of the figure in cm (default 10)", "height_cm": "Height of the figure in cm (default 6)", "display": "If True, widgets are immediately displayed", + # "controllers": "" }, "ephyviewer": {}, } @@ -45,7 +46,7 @@ def set_default_plotter_backend(backend): default_backend_kwargs = { "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None}, "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, - "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True}, + "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True, "controllers": None}, "ephyviewer": {}, } diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 9dc51f522e..604da35e65 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -128,7 +128,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -147,34 +147,29 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): with output: self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - if data_plot["unit_ids"] is None: - data_plot["unit_ids"] = [] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller + self.unit_selector = UnitSelector(data_plot["sorting"].unit_ids) + self.unit_selector.value = [ ] - for w in self.controller.values(): - w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=self.figure.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update self._update_ipywidget(None) + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + if backend_kwargs["display"]: display(self.widget) def _update_ipywidget(self, change): from matplotlib.lines import Line2D - unit_ids = self.controller["unit_ids"].value + unit_ids = self.unit_selector.value unit_colors = self.data_plot["unit_colors"] # matplotlib next_data_plot dict update at each call @@ -198,6 +193,7 @@ def _update_ipywidget(self, change): self.plot_matplotlib(self.data_plot, **backend_kwargs) if len(unit_ids) > 0: + # TODO later make option to control legend or not for l in self.figure.legends: l.remove() handles = [ @@ -212,6 +208,7 @@ def _update_ipywidget(self, change): self.figure.canvas.draw() self.figure.canvas.flush_events() + def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 9771b2c0e9..926051b8f9 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -191,7 +191,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -210,48 +210,36 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], - list(data_plot["unit_colors"].keys()), - ratios[0] * width_cm, - height_cm, - ) - - self.controller = unit_controller - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] self.widget = widgets.AppLayout( center=fig.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids + data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_all_units"] = True + # TODO add an option checkbox for legend data_plot["plot_legend"] = True data_plot["hide_axis"] = True - backend_kwargs = {} - backend_kwargs["ax"] = self.ax + backend_kwargs = dict(ax=self.ax) - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() fig.canvas.draw() diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index ae036d1ba1..2f748cc0fc 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -149,20 +149,20 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting = we.sorting # first plot time series - ts_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) - self.ax = ts_widget.ax - self.axes = ts_widget.axes - self.figure = ts_widget.figure + traces_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) + self.ax = traces_widget.ax + self.axes = traces_widget.axes + self.figure = traces_widget.figure ax = self.ax - frame_range = ts_widget.data_plot["frame_range"] - segment_index = ts_widget.data_plot["segment_index"] - min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1]) - max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1]) + frame_range = traces_widget.data_plot["frame_range"] + segment_index = traces_widget.data_plot["segment_index"] + min_y = np.min(traces_widget.data_plot["channel_locations"][:, 1]) + max_y = np.max(traces_widget.data_plot["channel_locations"][:, 1]) - n = len(ts_widget.data_plot["channel_ids"]) - order = ts_widget.data_plot["order"] + n = len(traces_widget.data_plot["channel_ids"]) + order = traces_widget.data_plot["order"] if order is None: order = np.arange(n) @@ -210,13 +210,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # construct waveforms label_set = False if len(spike_frames_to_plot) > 0: - vspacing = ts_widget.data_plot["vspacing"] - traces = ts_widget.data_plot["list_traces"][0] + vspacing = traces_widget.data_plot["vspacing"] + traces = traces_widget.data_plot["list_traces"][0] waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1) + waveform_idxs = np.clip(waveform_idxs, 0, len(traces_widget.data_plot["times"]) - 1) - times = ts_widget.data_plot["times"][waveform_idxs] + times = traces_widget.data_plot["times"][waveform_idxs] # discontinuity times[:, -1] = np.nan @@ -224,7 +224,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): waveforms = traces[waveform_idxs] # [:, :, order] waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) - for i, chan_id in enumerate(ts_widget.data_plot["channel_ids"]): + for i, chan_id in enumerate(traces_widget.data_plot["channel_ids"]): offset = vspacing * i if chan_id in chan_ids: l = ax.plot(times_r, offset + waveforms_r[:, i], color=dp.unit_colors[unit]) @@ -232,13 +232,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): handles.append(l[0]) labels.append(unit) label_set = True - ax.legend(handles, labels) + # ax.legend(handles, labels) def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -256,37 +256,58 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] # plot timeseries - ts_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) - self.ax = ts_widget.ax - self.axes = ts_widget.axes - self.figure = ts_widget.figure + self._traces_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) + self.ax = self._traces_widget.ax + self.axes = self._traces_widget.axes + self.figure = self._traces_widget.figure - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.sampling_frequency = self._traces_widget.rec0.sampling_frequency - self.controller = dict() - self.controller.update(ts_widget.controller) - self.controller.update(unit_controller) + self.time_slider = self._traces_widget.time_slider - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] - self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0]) + self.widget = widgets.AppLayout(center=self._traces_widget.widget, + left_sidebar=self.unit_selector, + pane_widths=ratios + [0]) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + # remove callback from traces_widget + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self._traces_widget.time_slider.observe(self._update_ipywidget, names='value', type="change") + self._traces_widget.channel_selector.observe(self._update_ipywidget, names='value', type="change") + self._traces_widget.scaler.observe(self._update_ipywidget, names='value', type="change") + if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value + # TODO later: this is still a bit buggy because it make double refresh one from _traces_widget and one internal + + unit_ids = self.unit_selector.value + start_frame, end_frame, segment_index = self._traces_widget.time_slider.value + channel_ids = self._traces_widget.channel_selector.value + mode = self._traces_widget.mode_selector.value data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids + data_plot["options"].update( + dict( + channel_ids=channel_ids, + segment_index=segment_index, + # frame_range=(start_frame, end_frame), + time_range=np.array([start_frame, end_frame]) / self.sampling_frequency, + mode=mode, + with_colorbar=False, + ) + ) + backend_kwargs = {} backend_kwargs["ax"] = self.ax diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 42267e711f..8526a95d60 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -167,7 +167,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -186,42 +186,35 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] self.widget = widgets.AppLayout( center=fig.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids + data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_all_units"] = True + # TODO later add an option checkbox for legend data_plot["plot_legend"] = True data_plot["hide_axis"] = True - backend_kwargs = {} - backend_kwargs["ax"] = self.ax + backend_kwargs = dict(ax=self.ax) self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index e64765b44b..f01c842b66 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -250,7 +250,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -274,44 +274,33 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.fig_probe, self.ax_probe = plt.subplots(figsize=((ratios[2] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] + - same_axis_button = widgets.Checkbox( + self.same_axis_button = widgets.Checkbox( value=False, description="same axis", disabled=False, ) - plot_templates_button = widgets.Checkbox( + self.plot_templates_button = widgets.Checkbox( value=True, description="plot templates", disabled=False, ) - hide_axis_button = widgets.Checkbox( + self.hide_axis_button = widgets.Checkbox( value=True, description="hide axis", disabled=False, ) - footer = widgets.HBox([same_axis_button, plot_templates_button, hide_axis_button]) - - self.controller = { - "same_axis": same_axis_button, - "plot_templates": plot_templates_button, - "hide_axis": hide_axis_button, - } - self.controller.update(unit_controller) - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + footer = widgets.HBox([self.same_axis_button, self.plot_templates_button, self.hide_axis_button]) self.widget = widgets.AppLayout( center=self.fig_wf.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, right_sidebar=self.fig_probe.canvas, pane_widths=ratios, footer=footer, @@ -320,6 +309,11 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget(None) + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + for w in self.same_axis_button, self.plot_templates_button, self.hide_axis_button: + w.observe(self._update_ipywidget, names='value', type="change") + + if backend_kwargs["display"]: display(self.widget) @@ -327,10 +321,15 @@ def _update_ipywidget(self, change): self.fig_wf.clear() self.ax_probe.clear() - unit_ids = self.controller["unit_ids"].value - same_axis = self.controller["same_axis"].value - plot_templates = self.controller["plot_templates"].value - hide_axis = self.controller["hide_axis"].value + # unit_ids = self.controller["unit_ids"].value + unit_ids = self.unit_selector.value + # same_axis = self.controller["same_axis"].value + # plot_templates = self.controller["plot_templates"].value + # hide_axis = self.controller["hide_axis"].value + + same_axis = self.same_axis_button.value + plot_templates = self.plot_templates_button.value + hide_axis = self.hide_axis_button.value # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot @@ -341,6 +340,8 @@ def _update_ipywidget(self, change): data_plot["plot_templates"] = plot_templates if data_plot["plot_waveforms"]: data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} + + # TODO option for plot_legend backend_kwargs = {} @@ -369,6 +370,7 @@ def _update_ipywidget(self, change): self.ax_probe.axis("off") self.ax_probe.axis("equal") + # TODO this could be done with probeinterface plotting plotting tools!! for unit in unit_ids: channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] self.ax_probe.plot( diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index d2c41f234a..57550c0910 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -11,102 +11,6 @@ def check_ipywidget_backend(): assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" -# def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): -# time_slider = W.FloatSlider( -# orientation="horizontal", -# description="time:", -# value=time_range[0], -# min=t_start, -# max=t_stop, -# continuous_update=False, -# layout=W.Layout(width=f"{width_cm}cm"), -# ) -# layer_selector = W.Dropdown(description="layer", options=layer_keys) -# segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) -# window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") -# mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) -# all_layers = W.Checkbox(description="plot all layers", value=all_layers) - -# controller = { -# "layer_key": layer_selector, -# "segment_index": segment_selector, -# "window": window_sizer, -# "t_start": time_slider, -# "mode": mode_selector, -# "all_layers": all_layers, -# } -# widget = W.VBox( -# [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] -# ) - -# return widget, controller - - -def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): - unit_label = W.Label(value="units:") - - unit_selector = W.SelectMultiple( - options=all_unit_ids, - value=list(unit_ids), - disabled=False, - layout=W.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), - ) - - controller = {"unit_ids": unit_selector} - widget = W.VBox([unit_label, unit_selector]) - - return widget, controller - - -# def make_channel_controller(recording, width_cm, height_cm): -# channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) -# channel_selector = W.IntRangeSlider( -# value=[0, recording.get_num_channels()], -# min=0, -# max=recording.get_num_channels(), -# step=1, -# disabled=False, -# continuous_update=False, -# orientation="vertical", -# readout=True, -# readout_format="d", -# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), -# ) - -# controller = {"channel_inds": channel_selector} -# widget = W.VBox([channel_label, channel_selector]) - -# return widget, controller - - -# def make_scale_controller(width_cm, height_cm): -# scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) - -# plus_selector = W.Button( -# description="", -# disabled=False, -# button_style="", # 'success', 'info', 'warning', 'danger' or '' -# tooltip="Increase scale", -# icon="arrow-up", -# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), -# ) - -# minus_selector = W.Button( -# description="", -# disabled=False, -# button_style="", # 'success', 'info', 'warning', 'danger' or '' -# tooltip="Decrease scale", -# icon="arrow-down", -# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), -# ) - -# controller = {"plus": plus_selector, "minus": minus_selector} -# widget = W.VBox([scale_label, plus_selector, minus_selector]) - -# return widget, controller - - - class TimeSlider(W.HBox): value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) From f315594b0b88bed01f01232688d62c4c2e4bc0fe Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 15:49:47 +0200 Subject: [PATCH 7/8] protect TimeSlider on the upper limit to avoid border effect on window size --- src/spikeinterface/widgets/utils_ipywidgets.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 57550c0910..ee6133a990 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -54,7 +54,7 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): # description='time:', value=start_frame, min=0, - max=self.frame_limits[self.segment_index], + max=self.frame_limits[self.segment_index] - 1, readout=False, continuous_update=False, layout=W.Layout(width=f'70%') @@ -112,10 +112,13 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update else: start_frame = new_frame delta_s = self.window_sizer.value - end_frame = start_frame + int(delta_s * self.sampling_frequency) - + delta = int(delta_s * self.sampling_frequency) + # clip + start_frame = min(self.frame_limits[self.segment_index] - delta, start_frame) start_frame = max(0, start_frame) + end_frame = start_frame + delta + end_frame = min(self.frame_limits[self.segment_index], end_frame) @@ -170,7 +173,7 @@ def segment_changed(self, change=None): self.slider.unobserve(self.slider_moved, names='value', type="change") # self.slider.value = 0 - self.slider.max = self.frame_limits[self.segment_index] + self.slider.max = self.frame_limits[self.segment_index] - 1 self.slider.observe(self.slider_moved, names='value', type="change") self.update_time(new_frame=0, update_slider=True, update_label=True) From 2c015f78e9311e106e9d2fda4e4026a61ca68c5b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 09:28:28 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/amplitudes.py | 7 +- src/spikeinterface/widgets/base.py | 2 +- src/spikeinterface/widgets/metrics.py | 6 +- src/spikeinterface/widgets/spike_locations.py | 2 +- .../widgets/spikes_on_traces.py | 20 +- src/spikeinterface/widgets/traces.py | 51 ++-- src/spikeinterface/widgets/unit_locations.py | 2 +- src/spikeinterface/widgets/unit_waveforms.py | 8 +- .../widgets/utils_ipywidgets.py | 222 +++++++++--------- 9 files changed, 163 insertions(+), 157 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 5aa090b1b4..6b6496a577 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -174,6 +174,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt + # import ipywidgets.widgets as widgets import ipywidgets.widgets as W from IPython.display import display @@ -210,7 +211,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector, self.checkbox_histograms, ], - layout = W.Layout(align_items="center", width="4cm", height="100%"), + layout=W.Layout(align_items="center", width="4cm", height="100%"), ) self.widget = W.AppLayout( @@ -222,8 +223,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._full_update_plot() - self.unit_selector.observe(self._update_plot, names='value', type="change") - self.checkbox_histograms.observe(self._full_update_plot, names='value', type="change") + self.unit_selector.observe(self._update_plot, names="value", type="change") + self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change") if backend_kwargs["display"]: display(self.widget) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 1ff691320a..9fc7b73707 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -38,7 +38,7 @@ def set_default_plotter_backend(backend): "width_cm": "Width of the figure in cm (default 10)", "height_cm": "Height of the figure in cm (default 6)", "display": "If True, widgets are immediately displayed", - # "controllers": "" + # "controllers": "" }, "ephyviewer": {}, } diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 604da35e65..c7b701c8b0 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -149,8 +149,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): plt.show() self.unit_selector = UnitSelector(data_plot["sorting"].unit_ids) - self.unit_selector.value = [ ] - + self.unit_selector.value = [] self.widget = widgets.AppLayout( center=self.figure.canvas, @@ -161,7 +160,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget(None) - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) @@ -208,7 +207,6 @@ def _update_ipywidget(self, change): self.figure.canvas.draw() self.figure.canvas.flush_events() - def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 926051b8f9..fda2356105 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -222,7 +222,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget() - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 2f748cc0fc..c2bed8fe41 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -232,7 +232,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): handles.append(l[0]) labels.append(unit) label_set = True - # ax.legend(handles, labels) + # ax.legend(handles, labels) def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt @@ -268,19 +268,18 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector = UnitSelector(data_plot["unit_ids"]) self.unit_selector.value = list(data_plot["unit_ids"])[:1] - self.widget = widgets.AppLayout(center=self._traces_widget.widget, - left_sidebar=self.unit_selector, - pane_widths=ratios + [0]) + self.widget = widgets.AppLayout( + center=self._traces_widget.widget, left_sidebar=self.unit_selector, pane_widths=ratios + [0] + ) # a first update self._update_ipywidget() # remove callback from traces_widget - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") - self._traces_widget.time_slider.observe(self._update_ipywidget, names='value', type="change") - self._traces_widget.channel_selector.observe(self._update_ipywidget, names='value', type="change") - self._traces_widget.scaler.observe(self._update_ipywidget, names='value', type="change") - + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.time_slider.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.channel_selector.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.scaler.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) @@ -305,10 +304,9 @@ def _update_ipywidget(self, change=None): time_range=np.array([start_frame, end_frame]) / self.sampling_frequency, mode=mode, with_colorbar=False, - ) + ) ) - backend_kwargs = {} backend_kwargs["ax"] = self.ax diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index d107c5cb23..9b6716e8f3 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -290,7 +290,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): check_ipywidget_backend() self.next_data_plot = data_plot.copy() - self.recordings = data_plot["recordings"] @@ -314,7 +313,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.time_slider = TimeSlider( durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], sampling_frequency=rec0.sampling_frequency, - # layout=W.Layout(height="2cm"), + # layout=W.Layout(height="2cm"), ) start_frame = int(data_plot["time_range"][0] * rec0.sampling_frequency) @@ -324,14 +323,17 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): _layer_keys = data_plot["layer_keys"] if len(_layer_keys) > 1: - _layer_keys = ['ALL'] + _layer_keys - self.layer_selector = W.Dropdown(options=_layer_keys, - layout=W.Layout(width="95%"), - ) - self.mode_selector = W.Dropdown(options=["line", "map"], value=data_plot["mode"], - # layout=W.Layout(width="5cm"), - layout=W.Layout(width="95%"), - ) + _layer_keys = ["ALL"] + _layer_keys + self.layer_selector = W.Dropdown( + options=_layer_keys, + layout=W.Layout(width="95%"), + ) + self.mode_selector = W.Dropdown( + options=["line", "map"], + value=data_plot["mode"], + # layout=W.Layout(width="5cm"), + layout=W.Layout(width="95%"), + ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) @@ -343,9 +345,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.mode_selector, self.scaler, # self.channel_selector, - ], + ], layout=W.Layout(width="3.5cm"), - align_items='center', + align_items="center", ) self.return_scaled = data_plot["return_scaled"] @@ -353,7 +355,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.widget = widgets.AppLayout( center=self.figure.canvas, footer=self.time_slider, - left_sidebar = left_sidebar, + left_sidebar=left_sidebar, right_sidebar=self.channel_selector, pane_heights=[0, 6, 1], pane_widths=ratios, @@ -365,28 +367,28 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # callbacks: # some widgets generate a full retrieve + refresh - self.time_slider.observe(self._retrieve_traces, names='value', type="change") - self.layer_selector.observe(self._retrieve_traces, names='value', type="change") - self.channel_selector.observe(self._retrieve_traces, names='value', type="change") + self.time_slider.observe(self._retrieve_traces, names="value", type="change") + self.layer_selector.observe(self._retrieve_traces, names="value", type="change") + self.channel_selector.observe(self._retrieve_traces, names="value", type="change") # other widgets only refresh - self.scaler.observe(self._update_plot, names='value', type="change") + self.scaler.observe(self._update_plot, names="value", type="change") # map is a special case because needs to check layer also - self.mode_selector.observe(self._mode_changed, names='value', type="change") - + self.mode_selector.observe(self._mode_changed, names="value", type="change") + if backend_kwargs["display"]: # self.check_backend() display(self.widget) def _get_layers(self): layer = self.layer_selector.value - if layer == 'ALL': + if layer == "ALL": layer_keys = self.data_plot["layer_keys"] else: layer_keys = [layer] if self.mode_selector.value == "map": layer_keys = layer_keys[:1] return layer_keys - + def _mode_changed(self, change=None): if self.mode_selector.value == "map" and self.layer_selector.value == "ALL": self.layer_selector.value = self.data_plot["layer_keys"][0] @@ -400,7 +402,7 @@ def _retrieve_traces(self, change=None): order, _ = order_channels_by_depth(self.rec0, channel_ids) else: order = None - + start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency @@ -439,9 +441,9 @@ def _update_plot(self, change=None): data_plot["clims"] = clims data_plot["channel_ids"] = self._channel_ids - + data_plot["layer_keys"] = layer_keys - data_plot["colors"] = {k:self.data_plot["colors"][k] for k in layer_keys} + data_plot["colors"] = {k: self.data_plot["colors"][k] for k in layer_keys} list_traces = [traces * self.scaler.value for traces in self._list_traces] data_plot["list_traces"] = list_traces @@ -458,7 +460,6 @@ def _update_plot(self, change=None): fig.canvas.draw() fig.canvas.flush_events() - def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import handle_display_and_url diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 8526a95d60..b41ee3508b 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -198,7 +198,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget() - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index f01c842b66..8ffc931bf2 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -277,7 +277,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector = UnitSelector(data_plot["unit_ids"]) self.unit_selector.value = list(data_plot["unit_ids"])[:1] - self.same_axis_button = widgets.Checkbox( value=False, description="same axis", @@ -309,10 +308,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget(None) - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") for w in self.same_axis_button, self.plot_templates_button, self.hide_axis_button: - w.observe(self._update_ipywidget, names='value', type="change") - + w.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) @@ -340,7 +338,7 @@ def _update_ipywidget(self, change): data_plot["plot_templates"] = plot_templates if data_plot["plot_waveforms"]: data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} - + # TODO option for plot_legend backend_kwargs = {} diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index ee6133a990..6e872eca55 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -12,12 +12,9 @@ def check_ipywidget_backend(): class TimeSlider(W.HBox): - value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) - - def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): - - + + def __init__(self, durations, sampling_frequency, time_range=(0, 1.0), **kwargs): self.num_segments = len(durations) self.frame_limits = [int(sampling_frequency * d) for d in durations] self.sampling_frequency = sampling_frequency @@ -28,81 +25,100 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): self.segment_index = 0 self.value = (start_frame, end_frame, self.segment_index) - - + layout = W.Layout(align_items="center", width="2.5cm", height="1.cm") - but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) - but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) - + but_left = W.Button(description="", disabled=False, button_style="", icon="arrow-left", layout=layout) + but_right = W.Button(description="", disabled=False, button_style="", icon="arrow-right", layout=layout) + but_left.on_click(self.move_left) but_right.on_click(self.move_right) - self.move_size = W.Dropdown(options=['10 ms', '100 ms', '1 s', '10 s', '1 m', '30 m', '1 h',], # '6 h', '24 h' - value='1 s', - description='', - layout = W.Layout(width="2cm") - ) + self.move_size = W.Dropdown( + options=[ + "10 ms", + "100 ms", + "1 s", + "10 s", + "1 m", + "30 m", + "1 h", + ], # '6 h', '24 h' + value="1 s", + description="", + layout=W.Layout(width="2cm"), + ) # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) - self.time_label = W.Text(value=f'{time_range[0]}',description='', - disabled=False, layout=W.Layout(width='2.5cm')) - self.time_label.observe(self.time_label_changed, names='value', type="change") - + self.time_label = W.Text( + value=f"{time_range[0]}", description="", disabled=False, layout=W.Layout(width="2.5cm") + ) + self.time_label.observe(self.time_label_changed, names="value", type="change") self.slider = W.IntSlider( - orientation='horizontal', - # description='time:', + orientation="horizontal", + # description='time:', value=start_frame, min=0, max=self.frame_limits[self.segment_index] - 1, readout=False, continuous_update=False, - layout=W.Layout(width=f'70%') + layout=W.Layout(width=f"70%"), ) - - self.slider.observe(self.slider_moved, names='value', type="change") - + + self.slider.observe(self.slider_moved, names="value", type="change") + delta_s = np.diff(self.frame_range) / sampling_frequency - - self.window_sizer = W.BoundedFloatText(value=delta_s, step=1, - min=0.01, max=30., - description='win (s)', - layout=W.Layout(width='auto') - # layout=W.Layout(width=f'10%') - ) - self.window_sizer.observe(self.win_size_changed, names='value', type="change") + + self.window_sizer = W.BoundedFloatText( + value=delta_s, + step=1, + min=0.01, + max=30.0, + description="win (s)", + layout=W.Layout(width="auto") + # layout=W.Layout(width=f'10%') + ) + self.window_sizer.observe(self.win_size_changed, names="value", type="change") self.segment_selector = W.Dropdown(description="segment", options=list(range(self.num_segments))) - self.segment_selector.observe(self.segment_changed, names='value', type="change") + self.segment_selector.observe(self.segment_changed, names="value", type="change") + + super(W.HBox, self).__init__( + children=[ + self.segment_selector, + but_left, + self.move_size, + but_right, + self.slider, + self.time_label, + self.window_sizer, + ], + layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) - super(W.HBox, self).__init__(children=[self.segment_selector, but_left, self.move_size, but_right, - self.slider, self.time_label, self.window_sizer], - layout=W.Layout(align_items="center", width="100%", height="100%"), - **kwargs) - - self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def value_changed(self, change=None): - - self.unobserve(self.value_changed, names=['value'], type="change") + self.unobserve(self.value_changed, names=["value"], type="change") start, stop, seg_index = self.value if seg_index < 0 or seg_index >= self.num_segments: - self.value = change['old'] + self.value = change["old"] return if start < 0 or stop < 0: - self.value = change['old'] + self.value = change["old"] return if start >= self.frame_limits[seg_index] or start > self.frame_limits[seg_index]: - self.value = change['old'] + self.value = change["old"] return - + self.segment_selector.value = seg_index self.update_time(new_frame=start, update_slider=True, update_label=True) delta_s = (stop - start) / self.sampling_frequency self.window_sizer.value = delta_s - self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def update_time(self, new_frame=None, new_time=None, update_slider=False, update_label=False): if new_frame is None and new_time is None: @@ -118,25 +134,24 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update start_frame = min(self.frame_limits[self.segment_index] - delta, start_frame) start_frame = max(0, start_frame) end_frame = start_frame + delta - + end_frame = min(self.frame_limits[self.segment_index], end_frame) - start_time = start_frame / self.sampling_frequency if update_label: - self.time_label.unobserve(self.time_label_changed, names='value', type="change") - self.time_label.value = f'{start_time}' - self.time_label.observe(self.time_label_changed, names='value', type="change") + self.time_label.unobserve(self.time_label_changed, names="value", type="change") + self.time_label.value = f"{start_time}" + self.time_label.observe(self.time_label_changed, names="value", type="change") if update_slider: - self.slider.unobserve(self.slider_moved, names='value', type="change") + self.slider.unobserve(self.slider_moved, names="value", type="change") self.slider.value = start_frame - self.slider.observe(self.slider_moved, names='value', type="change") - + self.slider.observe(self.slider_moved, names="value", type="change") + self.frame_range = (start_frame, end_frame) self.value = (start_frame, end_frame, self.segment_index) - + def time_label_changed(self, change=None): try: new_time = float(self.time_label.value) @@ -145,39 +160,39 @@ def time_label_changed(self, change=None): if new_time is not None: self.update_time(new_time=new_time, update_slider=True) - def win_size_changed(self, change=None): self.update_time() - + def slider_moved(self, change=None): new_frame = self.slider.value self.update_time(new_frame=new_frame, update_label=True) - + def move(self, sign): - value, units = self.move_size.value.split(' ') + value, units = self.move_size.value.split(" ") value = int(value) - delta_s = (sign * np.timedelta64(value, units)) / np.timedelta64(1, 's') + delta_s = (sign * np.timedelta64(value, units)) / np.timedelta64(1, "s") delta_sample = int(delta_s * self.sampling_frequency) new_frame = self.frame_range[0] + delta_sample self.slider.value = new_frame - + def move_left(self, change=None): self.move(-1) def move_right(self, change=None): self.move(+1) - + def segment_changed(self, change=None): self.segment_index = self.segment_selector.value - self.slider.unobserve(self.slider_moved, names='value', type="change") + self.slider.unobserve(self.slider_moved, names="value", type="change") # self.slider.value = 0 self.slider.max = self.frame_limits[self.segment_index] - 1 - self.slider.observe(self.slider_moved, names='value', type="change") + self.slider.observe(self.slider_moved, names="value", type="change") self.update_time(new_frame=0, update_slider=True, update_label=True) + class ChannelSelector(W.VBox): value = traitlets.List() @@ -211,22 +226,24 @@ def __init__(self, channel_ids, **kwargs): ) hbox = W.HBox(children=[self.slider, self.selector]) - super(W.VBox, self).__init__(children=[channel_label, hbox], - layout=W.Layout(align_items="center"), - # layout=W.Layout(align_items="center", width="100%", height="100%"), - **kwargs) - self.slider.observe(self.on_slider_changed, names=['value'], type="change") - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + super(W.VBox, self).__init__( + children=[channel_label, hbox], + layout=W.Layout(align_items="center"), + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + self.selector.observe(self.on_selector_changed, names=["value"], type="change") # TODO external value change # self.observe(self.value_changed, names=['value'], type="change") - + def on_slider_changed(self, change=None): i0, i1 = self.slider.value - - self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") + + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") self.selector.value = self.channel_ids[i0:i1][::-1] - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + self.selector.observe(self.on_selector_changed, names=["value"], type="change") self.value = self.channel_ids[i0:i1] @@ -235,27 +252,23 @@ def on_selector_changed(self, change=None): channel_ids = channel_ids[::-1] if len(channel_ids) > 0: - self.slider.unobserve(self.on_slider_changed, names=['value'], type="change") + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") i0 = self.channel_ids.index(channel_ids[0]) i1 = self.channel_ids.index(channel_ids[-1]) + 1 self.slider.value = (i0, i1) - self.slider.observe(self.on_slider_changed, names=['value'], type="change") + self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.value = channel_ids - class ScaleWidget(W.VBox): value = traitlets.Float() - def __init__(self, value=1., factor=1.2, **kwargs): - - assert factor > 1. + def __init__(self, value=1.0, factor=1.2, **kwargs): + assert factor > 1.0 self.factor = factor - self.scale_label = W.Label("Scale", - layout=W.Layout(layout=W.Layout(width='95%'), - justify_content="center")) + self.scale_label = W.Label("Scale", layout=W.Layout(layout=W.Layout(width="95%"), justify_content="center")) self.plus_selector = W.Button( description="", @@ -264,7 +277,7 @@ def __init__(self, value=1., factor=1.2, **kwargs): tooltip="Increase scale", icon="arrow-up", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='60%', align_self='center'), + layout=W.Layout(width="60%", align_self="center"), ) self.minus_selector = W.Button( @@ -274,31 +287,31 @@ def __init__(self, value=1., factor=1.2, **kwargs): tooltip="Decrease scale", icon="arrow-down", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='60%', align_self='center'), + layout=W.Layout(width="60%", align_self="center"), ) self.plus_selector.on_click(self.plus_clicked) self.minus_selector.on_click(self.minus_clicked) - self.value = 1. - super(W.VBox, self).__init__(children=[self.plus_selector, self.scale_label, self.minus_selector], - # layout=W.Layout(align_items="center", width="100%", height="100%"), - **kwargs) + self.value = 1.0 + super(W.VBox, self).__init__( + children=[self.plus_selector, self.scale_label, self.minus_selector], + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) self.update_label() - self.observe(self.value_changed, names=['value'], type="change") - + self.observe(self.value_changed, names=["value"], type="change") + def update_label(self): self.scale_label.value = f"Scale: {self.value:0.2f}" - def plus_clicked(self, change=None): self.value = self.value * self.factor def minus_clicked(self, change=None): self.value = self.value / self.factor - def value_changed(self, change=None): self.update_label() @@ -319,20 +332,17 @@ def __init__(self, unit_ids, **kwargs): layout=W.Layout(height="100%", width="2cm"), ) - super(W.VBox, self).__init__(children=[label, self.selector], - layout=W.Layout(align_items="center"), - **kwargs) - - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + super(W.VBox, self).__init__(children=[label, self.selector], layout=W.Layout(align_items="center"), **kwargs) + + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + self.observe(self.value_changed, names=["value"], type="change") - self.observe(self.value_changed, names=['value'], type="change") - def on_selector_changed(self, change=None): unit_ids = self.selector.value self.value = unit_ids - - def value_changed(self, change=None): - self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") - self.selector.value = change['new'] - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change")