Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 27, 2023
1 parent 39090d8 commit 2c015f7
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 157 deletions.
7 changes: 4 additions & 3 deletions src/spikeinterface/widgets/amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt

# import ipywidgets.widgets as widgets
import ipywidgets.widgets as W
from IPython.display import display
Expand Down Expand Up @@ -210,7 +211,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.unit_selector,
self.checkbox_histograms,
],
layout = W.Layout(align_items="center", width="4cm", height="100%"),
layout=W.Layout(align_items="center", width="4cm", height="100%"),
)

self.widget = W.AppLayout(
Expand All @@ -222,8 +223,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
# a first update
self._full_update_plot()

self.unit_selector.observe(self._update_plot, names='value', type="change")
self.checkbox_histograms.observe(self._full_update_plot, names='value', type="change")
self.unit_selector.observe(self._update_plot, names="value", type="change")
self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def set_default_plotter_backend(backend):
"width_cm": "Width of the figure in cm (default 10)",
"height_cm": "Height of the figure in cm (default 6)",
"display": "If True, widgets are immediately displayed",
# "controllers": ""
# "controllers": ""
},
"ephyviewer": {},
}
Expand Down
6 changes: 2 additions & 4 deletions src/spikeinterface/widgets/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
plt.show()

self.unit_selector = UnitSelector(data_plot["sorting"].unit_ids)
self.unit_selector.value = [ ]

self.unit_selector.value = []

self.widget = widgets.AppLayout(
center=self.figure.canvas,
Expand All @@ -161,7 +160,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
# a first update
self._update_ipywidget(None)

self.unit_selector.observe(self._update_ipywidget, names='value', type="change")
self.unit_selector.observe(self._update_ipywidget, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)
Expand Down Expand Up @@ -208,7 +207,6 @@ def _update_ipywidget(self, change):
self.figure.canvas.draw()
self.figure.canvas.flush_events()


def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/spike_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
# a first update
self._update_ipywidget()

self.unit_selector.observe(self._update_ipywidget, names='value', type="change")
self.unit_selector.observe(self._update_ipywidget, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)
Expand Down
20 changes: 9 additions & 11 deletions src/spikeinterface/widgets/spikes_on_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
handles.append(l[0])
labels.append(unit)
label_set = True
# ax.legend(handles, labels)
# ax.legend(handles, labels)

def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -268,19 +268,18 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.unit_selector = UnitSelector(data_plot["unit_ids"])
self.unit_selector.value = list(data_plot["unit_ids"])[:1]

self.widget = widgets.AppLayout(center=self._traces_widget.widget,
left_sidebar=self.unit_selector,
pane_widths=ratios + [0])
self.widget = widgets.AppLayout(
center=self._traces_widget.widget, left_sidebar=self.unit_selector, pane_widths=ratios + [0]
)

# a first update
self._update_ipywidget()

# remove callback from traces_widget
self.unit_selector.observe(self._update_ipywidget, names='value', type="change")
self._traces_widget.time_slider.observe(self._update_ipywidget, names='value', type="change")
self._traces_widget.channel_selector.observe(self._update_ipywidget, names='value', type="change")
self._traces_widget.scaler.observe(self._update_ipywidget, names='value', type="change")

self.unit_selector.observe(self._update_ipywidget, names="value", type="change")
self._traces_widget.time_slider.observe(self._update_ipywidget, names="value", type="change")
self._traces_widget.channel_selector.observe(self._update_ipywidget, names="value", type="change")
self._traces_widget.scaler.observe(self._update_ipywidget, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)
Expand All @@ -305,10 +304,9 @@ def _update_ipywidget(self, change=None):
time_range=np.array([start_frame, end_frame]) / self.sampling_frequency,
mode=mode,
with_colorbar=False,
)
)
)


backend_kwargs = {}
backend_kwargs["ax"] = self.ax

Expand Down
51 changes: 26 additions & 25 deletions src/spikeinterface/widgets/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
check_ipywidget_backend()

self.next_data_plot = data_plot.copy()


self.recordings = data_plot["recordings"]

Expand All @@ -314,7 +313,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.time_slider = TimeSlider(
durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())],
sampling_frequency=rec0.sampling_frequency,
# layout=W.Layout(height="2cm"),
# layout=W.Layout(height="2cm"),
)

start_frame = int(data_plot["time_range"][0] * rec0.sampling_frequency)
Expand All @@ -324,14 +323,17 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):

_layer_keys = data_plot["layer_keys"]
if len(_layer_keys) > 1:
_layer_keys = ['ALL'] + _layer_keys
self.layer_selector = W.Dropdown(options=_layer_keys,
layout=W.Layout(width="95%"),
)
self.mode_selector = W.Dropdown(options=["line", "map"], value=data_plot["mode"],
# layout=W.Layout(width="5cm"),
layout=W.Layout(width="95%"),
)
_layer_keys = ["ALL"] + _layer_keys
self.layer_selector = W.Dropdown(
options=_layer_keys,
layout=W.Layout(width="95%"),
)
self.mode_selector = W.Dropdown(
options=["line", "map"],
value=data_plot["mode"],
# layout=W.Layout(width="5cm"),
layout=W.Layout(width="95%"),
)
self.scaler = ScaleWidget()
self.channel_selector = ChannelSelector(self.rec0.channel_ids)

Expand All @@ -343,17 +345,17 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.mode_selector,
self.scaler,
# self.channel_selector,
],
],
layout=W.Layout(width="3.5cm"),
align_items='center',
align_items="center",
)

self.return_scaled = data_plot["return_scaled"]

self.widget = widgets.AppLayout(
center=self.figure.canvas,
footer=self.time_slider,
left_sidebar = left_sidebar,
left_sidebar=left_sidebar,
right_sidebar=self.channel_selector,
pane_heights=[0, 6, 1],
pane_widths=ratios,
Expand All @@ -365,28 +367,28 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):

# callbacks:
# some widgets generate a full retrieve + refresh
self.time_slider.observe(self._retrieve_traces, names='value', type="change")
self.layer_selector.observe(self._retrieve_traces, names='value', type="change")
self.channel_selector.observe(self._retrieve_traces, names='value', type="change")
self.time_slider.observe(self._retrieve_traces, names="value", type="change")
self.layer_selector.observe(self._retrieve_traces, names="value", type="change")
self.channel_selector.observe(self._retrieve_traces, names="value", type="change")
# other widgets only refresh
self.scaler.observe(self._update_plot, names='value', type="change")
self.scaler.observe(self._update_plot, names="value", type="change")
# map is a special case because needs to check layer also
self.mode_selector.observe(self._mode_changed, names='value', type="change")
self.mode_selector.observe(self._mode_changed, names="value", type="change")

if backend_kwargs["display"]:
# self.check_backend()
display(self.widget)

def _get_layers(self):
layer = self.layer_selector.value
if layer == 'ALL':
if layer == "ALL":
layer_keys = self.data_plot["layer_keys"]
else:
layer_keys = [layer]
if self.mode_selector.value == "map":
layer_keys = layer_keys[:1]
return layer_keys

def _mode_changed(self, change=None):
if self.mode_selector.value == "map" and self.layer_selector.value == "ALL":
self.layer_selector.value = self.data_plot["layer_keys"][0]
Expand All @@ -400,7 +402,7 @@ def _retrieve_traces(self, change=None):
order, _ = order_channels_by_depth(self.rec0, channel_ids)
else:
order = None

start_frame, end_frame, segment_index = self.time_slider.value
time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency

Expand Down Expand Up @@ -439,9 +441,9 @@ def _update_plot(self, change=None):

data_plot["clims"] = clims
data_plot["channel_ids"] = self._channel_ids

data_plot["layer_keys"] = layer_keys
data_plot["colors"] = {k:self.data_plot["colors"][k] for k in layer_keys}
data_plot["colors"] = {k: self.data_plot["colors"][k] for k in layer_keys}

list_traces = [traces * self.scaler.value for traces in self._list_traces]
data_plot["list_traces"] = list_traces
Expand All @@ -458,7 +460,6 @@ def _update_plot(self, change=None):
fig.canvas.draw()
fig.canvas.flush_events()


def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import handle_display_and_url
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/unit_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
# a first update
self._update_ipywidget()

self.unit_selector.observe(self._update_ipywidget, names='value', type="change")
self.unit_selector.observe(self._update_ipywidget, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)
Expand Down
8 changes: 3 additions & 5 deletions src/spikeinterface/widgets/unit_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.unit_selector = UnitSelector(data_plot["unit_ids"])
self.unit_selector.value = list(data_plot["unit_ids"])[:1]


self.same_axis_button = widgets.Checkbox(
value=False,
description="same axis",
Expand Down Expand Up @@ -309,10 +308,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
# a first update
self._update_ipywidget(None)

self.unit_selector.observe(self._update_ipywidget, names='value', type="change")
self.unit_selector.observe(self._update_ipywidget, names="value", type="change")
for w in self.same_axis_button, self.plot_templates_button, self.hide_axis_button:
w.observe(self._update_ipywidget, names='value', type="change")

w.observe(self._update_ipywidget, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)
Expand Down Expand Up @@ -340,7 +338,7 @@ def _update_ipywidget(self, change):
data_plot["plot_templates"] = plot_templates
if data_plot["plot_waveforms"]:
data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids}

# TODO option for plot_legend

backend_kwargs = {}
Expand Down
Loading

0 comments on commit 2c015f7

Please sign in to comment.