From 574955adaf638f89a36f254cd526ecae2253f7be Mon Sep 17 00:00:00 2001 From: FrancescoNegri Date: Tue, 9 Jul 2024 16:53:23 +0200 Subject: [PATCH] Run pre-commit locally --- src/spikeinterface/widgets/unit_spatial.py | 66 +++++++++++++--------- src/spikeinterface/widgets/widget_list.py | 2 +- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/widgets/unit_spatial.py b/src/spikeinterface/widgets/unit_spatial.py index 746aaf0a8e..06f253f07f 100644 --- a/src/spikeinterface/widgets/unit_spatial.py +++ b/src/spikeinterface/widgets/unit_spatial.py @@ -7,6 +7,7 @@ from warnings import warn from .base import BaseWidget, to_attr + class UnitSpatialDistributionsWidget(BaseWidget): """ Placeholder documentation to be changed. @@ -18,15 +19,20 @@ class UnitSpatialDistributionsWidget(BaseWidget): depth_axis : int, default: 1 The dimension of unit_locations that is depth """ - + def __init__( - self, - sorting_analyzer, probe=None, - depth_axis=1, bins=None, - cmap="viridis", kde=False, - depth_hist=True, groups=None, - kde_kws=None, - backend=None, **backend_kwargs + self, + sorting_analyzer, + probe=None, + depth_axis=1, + bins=None, + cmap="viridis", + kde=False, + depth_hist=True, + groups=None, + kde_kws=None, + backend=None, + **backend_kwargs, ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) @@ -44,12 +50,12 @@ def __init__( else: # TODO: throw error or warning, no probe available pass - + xrange, yrange, _ = get_auto_lims(probe, margin=0) if bins is None: bins = ( np.round(np.diff(xrange).squeeze() / 75).astype(int), - np.round(np.diff(yrange).squeeze() / 75).astype(int) + np.round(np.diff(yrange).squeeze() / 75).astype(int), ) # TODO: change behaviour, if bins is not defined, bin only along the depth axis @@ -68,7 +74,7 @@ def __init__( cmap=cmap, depth_hist=depth_hist, groups=groups, - kde_kws=kde_kws + kde_kws=kde_kws, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -103,16 +109,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): np.diff(dp.xrange).squeeze(), np.diff(dp.yrange).squeeze(), facecolor=dp.cmap.colors[0], - fill=True + fill=True, ) ) bg.set_clip_path(patch) - kdeplot( - data, x='x', y='y', - clip=[dp.xrange, dp.yrange], - cmap=dp.cmap, ax=ax, - **kde_kws - ) + kdeplot(data, x="x", y="y", clip=[dp.xrange, dp.yrange], cmap=dp.cmap, ax=ax, **kde_kws) pcm = ax.collections[0] ax.set_xlabel(None) ax.set_ylabel(None) @@ -122,12 +123,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): xlim, ylim, _ = get_auto_lims(dp.probe, margin=10) ax.set_xlim(*xlim) ax.set_ylim(*ylim) - ax.spines['top'].set_visible(False) - ax.spines['bottom'].set_visible(False) - ax.spines['right'].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["right"].set_visible(False) ax.set_xticks([]) - ax.set_xlabel('') - ax.set_ylabel('Depth (um)') + ax.set_xlabel("") + ax.set_ylabel("Depth (um)") if dp.depth_hist is True: bbox = ax.get_window_extent() @@ -135,8 +136,17 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax_hist = ax.inset_axes([1, 0, hist_height / bbox.width, 1]) data = dict(y=dp.y) - data['group'] = np.ones(dp.y.size) if dp.groups is None else dp.groups - palette = color_palette('bright', n_colors=1 if dp.groups is None else np.unique(dp.groups).size) - histplot(data=data, y='y', hue='group', bins=dp.bins[1], binrange=dp.yrange, palette=palette, ax=ax_hist, legend=False) - ax_hist.axis('off') - ax_hist.set_ylim(*ylim) \ No newline at end of file + data["group"] = np.ones(dp.y.size) if dp.groups is None else dp.groups + palette = color_palette("bright", n_colors=1 if dp.groups is None else np.unique(dp.groups).size) + histplot( + data=data, + y="y", + hue="group", + bins=dp.bins[1], + binrange=dp.yrange, + palette=palette, + ax=ax_hist, + legend=False, + ) + ax_hist.axis("off") + ax_hist.set_ylim(*ylim) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index abd9b700e5..ca4159cabb 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -159,4 +159,4 @@ def plot_timeseries(*args, **kwargs): warnings.warn("plot_timeseries() is now plot_traces()") - return plot_traces(*args, **kwargs) \ No newline at end of file + return plot_traces(*args, **kwargs)