Skip to content

Commit

Permalink
Run pre-commit locally
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescoNegri committed Jul 9, 2024
1 parent 64f9788 commit 574955a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 29 deletions.
66 changes: 38 additions & 28 deletions src/spikeinterface/widgets/unit_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from warnings import warn
from .base import BaseWidget, to_attr


class UnitSpatialDistributionsWidget(BaseWidget):
"""
Placeholder documentation to be changed.
Expand All @@ -18,15 +19,20 @@ class UnitSpatialDistributionsWidget(BaseWidget):
depth_axis : int, default: 1
The dimension of unit_locations that is depth
"""

def __init__(
self,
sorting_analyzer, probe=None,
depth_axis=1, bins=None,
cmap="viridis", kde=False,
depth_hist=True, groups=None,
kde_kws=None,
backend=None, **backend_kwargs
self,
sorting_analyzer,
probe=None,
depth_axis=1,
bins=None,
cmap="viridis",
kde=False,
depth_hist=True,
groups=None,
kde_kws=None,
backend=None,
**backend_kwargs,
):
sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

Expand All @@ -44,12 +50,12 @@ def __init__(
else:
# TODO: throw error or warning, no probe available
pass

xrange, yrange, _ = get_auto_lims(probe, margin=0)
if bins is None:
bins = (
np.round(np.diff(xrange).squeeze() / 75).astype(int),
np.round(np.diff(yrange).squeeze() / 75).astype(int)
np.round(np.diff(yrange).squeeze() / 75).astype(int),
)
# TODO: change behaviour, if bins is not defined, bin only along the depth axis

Expand All @@ -68,7 +74,7 @@ def __init__(
cmap=cmap,
depth_hist=depth_hist,
groups=groups,
kde_kws=kde_kws
kde_kws=kde_kws,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
Expand Down Expand Up @@ -103,16 +109,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
np.diff(dp.xrange).squeeze(),
np.diff(dp.yrange).squeeze(),
facecolor=dp.cmap.colors[0],
fill=True
fill=True,
)
)
bg.set_clip_path(patch)
kdeplot(
data, x='x', y='y',
clip=[dp.xrange, dp.yrange],
cmap=dp.cmap, ax=ax,
**kde_kws
)
kdeplot(data, x="x", y="y", clip=[dp.xrange, dp.yrange], cmap=dp.cmap, ax=ax, **kde_kws)
pcm = ax.collections[0]
ax.set_xlabel(None)
ax.set_ylabel(None)
Expand All @@ -122,21 +123,30 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
xlim, ylim, _ = get_auto_lims(dp.probe, margin=10)
ax.set_xlim(*xlim)
ax.set_ylim(*ylim)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["bottom"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xticks([])
ax.set_xlabel('')
ax.set_ylabel('Depth (um)')
ax.set_xlabel("")
ax.set_ylabel("Depth (um)")

if dp.depth_hist is True:
bbox = ax.get_window_extent()
hist_height = 1.5 * bbox.width

ax_hist = ax.inset_axes([1, 0, hist_height / bbox.width, 1])
data = dict(y=dp.y)
data['group'] = np.ones(dp.y.size) if dp.groups is None else dp.groups
palette = color_palette('bright', n_colors=1 if dp.groups is None else np.unique(dp.groups).size)
histplot(data=data, y='y', hue='group', bins=dp.bins[1], binrange=dp.yrange, palette=palette, ax=ax_hist, legend=False)
ax_hist.axis('off')
ax_hist.set_ylim(*ylim)
data["group"] = np.ones(dp.y.size) if dp.groups is None else dp.groups
palette = color_palette("bright", n_colors=1 if dp.groups is None else np.unique(dp.groups).size)
histplot(
data=data,
y="y",
hue="group",
bins=dp.bins[1],
binrange=dp.yrange,
palette=palette,
ax=ax_hist,
legend=False,
)
ax_hist.axis("off")
ax_hist.set_ylim(*ylim)
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/widget_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,4 @@

def plot_timeseries(*args, **kwargs):
warnings.warn("plot_timeseries() is now plot_traces()")
return plot_traces(*args, **kwargs)
return plot_traces(*args, **kwargs)

0 comments on commit 574955a

Please sign in to comment.