From 8054b3aa82a52495c3b22dea6ba2629cb22d42ac Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 May 2024 19:16:10 +0200 Subject: [PATCH 01/23] Extend plot waveforms/templates to Templates object --- src/spikeinterface/widgets/unit_depths.py | 2 +- src/spikeinterface/widgets/unit_summary.py | 2 +- src/spikeinterface/widgets/unit_templates.py | 6 +- src/spikeinterface/widgets/unit_waveforms.py | 233 ++++++++++++------ .../widgets/unit_waveforms_density_map.py | 2 +- src/spikeinterface/widgets/utils.py | 11 +- 6 files changed, 177 insertions(+), 79 deletions(-) diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index c5fe3e05e8..c2e9c06863 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -35,7 +35,7 @@ def __init__( unit_ids = sorting_analyzer.sorting.unit_ids if unit_colors is None: - unit_colors = get_unit_colors(sorting_analyzer.sorting) + unit_colors = get_unit_colors(sorting_analyzer) colors = [unit_colors[unit_id] for unit_id in unit_ids] diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index ea6476784e..0b2a348edf 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -48,7 +48,7 @@ def __init__( sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) if unit_colors is None: - unit_colors = get_unit_colors(sorting_analyzer.sorting) + unit_colors = get_unit_colors(sorting_analyzer) plot_data = dict( sorting_analyzer=sorting_analyzer, diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 1350bb71a5..eb9a90d1d1 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -1,5 +1,6 @@ from __future__ import annotations +from ..core import SortingAnalyzer from .unit_waveforms import UnitWaveformsWidget from .base import to_attr @@ -17,6 +18,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) + sorting_analyzer = dp.sorting_analyzer_or_templates + assert isinstance(sorting_analyzer, SortingAnalyzer), "This widget requires a SortingAnalyzer as input" + assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview" # ensure serializable for sortingview @@ -50,7 +54,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations) if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.sorting_analyzer.sorting) + v_units_table = generate_unit_table_view(sorting_analyzer.sorting) self.view = vv.Box( direction="horizontal", diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index f701f9a868..2b3dc7ed34 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -6,7 +6,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from ..core import ChannelSparsity, SortingAnalyzer +from ..core import ChannelSparsity, SortingAnalyzer, Templates from ..core.basesorting import BaseSorting @@ -16,8 +16,9 @@ class UnitWaveformsWidget(BaseWidget): Parameters ---------- - sorting_analyzer : SortingAnalyzer - The SortingAnalyzer + sorting_analyzer_or_templates : SortingAnalyzer | Templates + The SortingAnalyzer or Templates object. + If Templates is given, the "plot_waveforms" argument is set to False channel_ids: list or None, default: None The channel ids to display unit_ids : list or None, default: None @@ -39,6 +40,8 @@ class UnitWaveformsWidget(BaseWidget): displayed per waveform, (matplotlib backend) scale : float, default: 1 Scale factor for the waveforms/templates (matplotlib backend) + widen_narrow_scale : float, default: 1 + Scale factor for the x-axis of the waveforms/templates (matplotlib backend) axis_equal : bool, default: False Equal aspect ratio for x and y axis, to visualize the array geometry to scale lw_waveforms : float, default: 1 @@ -64,6 +67,8 @@ class UnitWaveformsWidget(BaseWidget): are used for the lower bounds, and the second half for the upper bounds. Inner elements produce darker shadings. For sortingview backend only 2 or 4 elements are supported. + scalebar : bool, default: False + Display a scale bar on the waveforms plot (matplotlib backend) hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed same_axis : bool, default: False @@ -77,7 +82,7 @@ class UnitWaveformsWidget(BaseWidget): def __init__( self, - sorting_analyzer: SortingAnalyzer, + sorting_analyzer_or_templates: SortingAnalyzer | Templates, channel_ids=None, unit_ids=None, plot_waveforms=True, @@ -87,6 +92,7 @@ def __init__( sparsity=None, ncols=5, scale=1, + widen_narrow_scale=1, lw_waveforms=1, lw_templates=2, axis_equal=False, @@ -96,6 +102,7 @@ def __init__( same_axis=False, shade_templates=True, templates_percentile_shading=(1, 25, 75, 99), + scalebar=False, x_offset_units=False, alpha_waveforms=0.5, alpha_templates=1, @@ -104,29 +111,29 @@ def __init__( backend=None, **backend_kwargs, ): - - sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - sorting: BaseSorting = sorting_analyzer.sorting + if not isinstance(sorting_analyzer_or_templates, Templates): + sorting_analyzer_or_templates = self.ensure_sorting_analyzer(sorting_analyzer_or_templates) + else: + plot_waveforms = False + shade_templates = False if unit_ids is None: - unit_ids = sorting.unit_ids + unit_ids = sorting_analyzer_or_templates.unit_ids if channel_ids is None: - channel_ids = sorting_analyzer.channel_ids + channel_ids = sorting_analyzer_or_templates.channel_ids if unit_colors is None: - unit_colors = get_unit_colors(sorting) - - channel_locations = sorting_analyzer.get_channel_locations()[ - sorting_analyzer.channel_ids_to_indices(channel_ids) - ] + unit_colors = get_unit_colors(sorting_analyzer_or_templates) + channel_indices = [list(sorting_analyzer_or_templates.channel_ids).index(ch) for ch in channel_ids] + channel_locations = sorting_analyzer_or_templates.get_channel_locations()[channel_indices] extra_sparsity = False - if sorting_analyzer.is_sparse(): + if sorting_analyzer_or_templates.sparsity is not None: if sparsity is None: - sparsity = sorting_analyzer.sparsity + sparsity = sorting_analyzer_or_templates.sparsity else: # assert provided sparsity is a subset of waveform sparsity - combined_mask = np.logical_or(sorting_analyzer.sparsity.mask, sparsity.mask) - assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer.sparsity.mask, 1) == 0), ( + combined_mask = np.logical_or(sorting_analyzer_or_templates.sparsity.mask, sparsity.mask) + assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0), ( "The provided 'sparsity' needs to include only the sparse channels " "used to extract waveforms (for example, by using a smaller 'radius_um')." ) @@ -134,41 +141,49 @@ def __init__( else: if sparsity is None: # in this case, we construct a dense sparsity - unit_id_to_channel_ids = {u: sorting_analyzer.channel_ids for u in sorting_analyzer.unit_ids} + unit_id_to_channel_ids = { + u: sorting_analyzer_or_templates.channel_ids for u in sorting_analyzer_or_templates.unit_ids + } sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, - unit_ids=sorting_analyzer.unit_ids, - channel_ids=sorting_analyzer.channel_ids, + unit_ids=sorting_analyzer_or_templates.unit_ids, + channel_ids=sorting_analyzer_or_templates.channel_ids, ) else: assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" # get templates - self.templates_ext = sorting_analyzer.get_extension("templates") - assert self.templates_ext is not None, "plot_waveforms() need extension 'templates'" - templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") - - if templates_percentile_shading is not None and not sorting_analyzer.has_extension("waveforms"): - warn( - "templates_percentile_shading can only be used if the 'waveforms' extension is available. " - "Settimg templates_percentile_shading to None." - ) - templates_percentile_shading = None - templates_shading = self._get_template_shadings(sorting_analyzer, unit_ids, templates_percentile_shading) - - xvectors, y_scale, y_offset, delta_x = get_waveforms_scales( - sorting_analyzer, templates, channel_locations, x_offset_units - ) + if isinstance(sorting_analyzer_or_templates, Templates): + templates = sorting_analyzer_or_templates.templates_array + nbefore = sorting_analyzer_or_templates.nbefore + self.templates_ext = None + templates_shading = None + else: + self.templates_ext = sorting_analyzer_or_templates.get_extension("templates") + assert self.templates_ext is not None, "plot_waveforms() need extension 'templates'" + templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") + nbefore = self.templates_ext.nbefore + + if templates_percentile_shading is not None and not sorting_analyzer_or_templates.has_extension( + "waveforms" + ): + warn( + "templates_percentile_shading can only be used if the 'waveforms' extension is available. " + "Settimg templates_percentile_shading to None." + ) + templates_percentile_shading = None + templates_shading = self._get_template_shadings(unit_ids, templates_percentile_shading) wfs_by_ids = {} if plot_waveforms: - wf_ext = sorting_analyzer.get_extension("waveforms") + # this must be a sorting_analyzer + wf_ext = sorting_analyzer_or_templates.get_extension("waveforms") if wf_ext is None: raise ValueError("plot_waveforms() needs the extension 'waveforms'") for unit_id in unit_ids: - unit_index = list(sorting.unit_ids).index(unit_id) + unit_index = list(sorting_analyzer_or_templates.unit_ids).index(unit_id) if not extra_sparsity: - if sorting_analyzer.is_sparse(): + if sorting_analyzer_or_templates.is_sparse(): # wfs = we.get_waveforms(unit_id) wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) else: @@ -181,7 +196,7 @@ def __init__( # wfs = we.get_waveforms(unit_id) wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) # find additional slice to apply to sparse waveforms - (wfs_sparse_indices,) = np.nonzero(sorting_analyzer.sparsity.mask[unit_index]) + (wfs_sparse_indices,) = np.nonzero(sorting_analyzer_or_templates.sparsity.mask[unit_index]) (extra_sparse_indices,) = np.nonzero(sparsity.mask[unit_index]) (extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices)) # apply extra sparsity @@ -189,14 +204,16 @@ def __init__( wfs_by_ids[unit_id] = wfs plot_data = dict( - sorting_analyzer=sorting_analyzer, - sampling_frequency=sorting_analyzer.sampling_frequency, + sorting_analyzer_or_templates=sorting_analyzer_or_templates, + sampling_frequency=sorting_analyzer_or_templates.sampling_frequency, + nbefore=nbefore, unit_ids=unit_ids, channel_ids=channel_ids, sparsity=sparsity, unit_colors=unit_colors, channel_locations=channel_locations, scale=scale, + widen_narrow_scale=widen_narrow_scale, templates=templates, templates_shading=templates_shading, do_shading=shade_templates, @@ -207,19 +224,16 @@ def __init__( unit_selected_waveforms=unit_selected_waveforms, axis_equal=axis_equal, max_spikes_per_unit=max_spikes_per_unit, - xvectors=xvectors, - y_scale=y_scale, - y_offset=y_offset, wfs_by_ids=wfs_by_ids, set_title=set_title, same_axis=same_axis, + scalebar=scalebar, templates_percentile_shading=templates_percentile_shading, x_offset_units=x_offset_units, lw_waveforms=lw_waveforms, lw_templates=lw_templates, alpha_waveforms=alpha_waveforms, alpha_templates=alpha_templates, - delta_x=delta_x, hide_unit_selector=hide_unit_selector, plot_legend=plot_legend, ) @@ -245,6 +259,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + xvectors, y_scale, y_offset, delta_x = get_waveforms_scales( + dp.templates, dp.channel_locations, dp.nbefore, dp.x_offset_units, dp.widen_narrow_scale + ) + for i, unit_id in enumerate(dp.unit_ids): if dp.same_axis: ax = self.ax @@ -253,7 +271,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): color = dp.unit_colors[unit_id] chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id] - xvectors_flat = dp.xvectors[:, chan_inds].T.flatten() + xvectors_flat = xvectors[:, chan_inds].T.flatten() # plot waveforms if dp.plot_waveforms: @@ -265,12 +283,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): random_idxs = np.random.permutation(len(wfs))[: dp.max_spikes_per_unit] wfs = wfs[random_idxs] - wfs = wfs * dp.y_scale + dp.y_offset[None, :, chan_inds] + wfs = wfs * y_scale + y_offset[None, :, chan_inds] wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T if dp.x_offset_units: # 0.7 is to match spacing in xvect - xvec = xvectors_flat + i * 0.7 * dp.delta_x + xvec = xvectors_flat + i * 0.7 * delta_x else: xvec = xvectors_flat @@ -278,14 +296,33 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if not dp.plot_templates: ax.get_lines()[-1].set_label(f"{unit_id}") + if not dp.plot_templates and dp.scalebar and not dp.same_axis: + # xscale + min_wfs = np.min(wfs_flat) + wfs_for_scale = dp.wfs_by_ids[unit_id] * y_scale + offset = 0.1 * (np.max(wfs_flat) - np.min(wfs_flat)) + xargmin = np.nanargmin(xvec) + xscale_bar = [xvec[xargmin], xvec[xargmin + dp.nbefore]] + ax.plot(xscale_bar, [min_wfs - offset, min_wfs - offset], color="k") + nbefore_time = int(dp.nbefore / dp.sampling_frequency * 1000) + ax.text( + xscale_bar[0] + xscale_bar[1] // 3, min_wfs - 1.5 * offset, f"{nbefore_time} ms", fontsize=8 + ) + + # yscale + length = int(np.ptp(wfs_flat) // 5) + length_uv = int(np.ptp(wfs_for_scale) // 5) + x_offset = xscale_bar[0] - np.ptp(xscale_bar) // 2 + ax.plot([xscale_bar[0], xscale_bar[0]], [min_wfs - offset, min_wfs - offset + length], color="k") + ax.text(x_offset, min_wfs - offset + length // 3, f"{length_uv} $\mu$V", fontsize=8, rotation=90) # plot template if dp.plot_templates: - template = dp.templates[i, :, :][:, chan_inds] * dp.scale * dp.y_scale + dp.y_offset[:, chan_inds] + template = dp.templates[i, :, :][:, chan_inds] * dp.scale * y_scale + y_offset[:, chan_inds] if dp.x_offset_units: # 0.7 is to match spacing in xvect - xvec = xvectors_flat + i * 0.7 * dp.delta_x + xvec = xvectors_flat + i * 0.7 * delta_x else: xvec = xvectors_flat # plot template shading if waveforms are not plotted @@ -297,12 +334,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): shading_alphas = np.linspace(lightest_gray_alpha, darkest_gray_alpha, n_shadings) for s in range(n_shadings): lower_bound = ( - dp.templates_shading[s][i, :, :][:, chan_inds] * dp.scale * dp.y_scale - + dp.y_offset[:, chan_inds] + dp.templates_shading[s][i, :, :][:, chan_inds] * dp.scale * y_scale + y_offset[:, chan_inds] ) upper_bound = ( - dp.templates_shading[n_percentiles - 1 - s][i, :, :][:, chan_inds] * dp.scale * dp.y_scale - + dp.y_offset[:, chan_inds] + dp.templates_shading[n_percentiles - 1 - s][i, :, :][:, chan_inds] * dp.scale * y_scale + + y_offset[:, chan_inds] ) ax.fill_between( xvec, @@ -332,6 +368,26 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.set_title: ax.set_title(f"template {template_label}") + if not dp.plot_waveforms and dp.scalebar and not dp.same_axis: + # xscale + template_for_scale = dp.templates[i, :, :][:, chan_inds] * dp.scale + min_wfs = np.min(template) + offset = 0.1 * (np.max(template) - np.min(template)) + xargmin = np.nanargmin(xvec) + xscale_bar = [xvec[xargmin], xvec[xargmin + dp.nbefore]] + ax.plot(xscale_bar, [min_wfs - offset, min_wfs - offset], color="k") + nbefore_time = int(dp.nbefore / dp.sampling_frequency * 1000) + ax.text( + xscale_bar[0] + xscale_bar[1] // 3, min_wfs - 1.5 * offset, f"{nbefore_time} ms", fontsize=8 + ) + + # yscale + length = int(np.ptp(template) // 5) + length_uv = int(np.ptp(template_for_scale) // 5) + x_offset = xscale_bar[0] - np.ptp(xscale_bar) // 2 + ax.plot([xscale_bar[0], xscale_bar[0]], [min_wfs - offset, min_wfs - offset + length], color="k") + ax.text(x_offset, min_wfs - offset + length // 3, f"{length_uv} $\mu$V", fontsize=8, rotation=90) + # plot channels if dp.plot_channels: # TODO enhance this @@ -348,14 +404,19 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, UnitSelector, ScaleWidget + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector, ScaleWidget, WidenNarrowWidget check_ipywidget_backend() self.next_data_plot = data_plot.copy() cm = 1 / 2.54 - self.sorting_analyzer = data_plot["sorting_analyzer"] + if isinstance(data_plot["sorting_analyzer_or_templates"], SortingAnalyzer): + self.sorting_analyzer = data_plot["sorting_analyzer_or_templates"] + self.templates = None + else: + self.sorting_analyzer = None + self.templates = data_plot["sorting_analyzer_or_templates"] width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -375,6 +436,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector = UnitSelector(data_plot["unit_ids"], layout=widgets.Layout(height="80%")) self.unit_selector.value = list(data_plot["unit_ids"])[:1] self.scaler = ScaleWidget(value=data_plot["scale"], layout=widgets.Layout(height="20%")) + self.widen_narrow = WidenNarrowWidget(value=1.0, layout=widgets.Layout(height="20%")) self.same_axis_button = widgets.Checkbox( value=False, @@ -400,10 +462,20 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): disabled=False, ) - footer = widgets.HBox( - [self.same_axis_button, self.plot_templates_button, self.template_shading_button, self.hide_axis_button] + self.scalebar = widgets.Checkbox( + value=False, + description="scalebar", + disabled=False, ) - left_sidebar = widgets.VBox([self.unit_selector, self.scaler]) + if self.sorting_analyzer is not None: + footer_list = [self.same_axis_button, self.template_shading_button, self.hide_axis_button, self.scalebar] + else: + footer_list = [self.same_axis_button, self.hide_axis_button, self.scalebar] + if data_plot["plot_waveforms"]: + footer_list.append(self.plot_templates_button) + + footer = widgets.HBox(footer_list) + left_sidebar = widgets.VBox([self.unit_selector, self.scaler, self.widen_narrow]) self.widget = widgets.AppLayout( center=self.fig_wf.canvas, @@ -418,13 +490,20 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector.observe(self._update_plot, names="value", type="change") self.scaler.observe(self._update_plot, names="value", type="change") - for w in self.same_axis_button, self.plot_templates_button, self.template_shading_button, self.hide_axis_button: + self.widen_narrow.observe(self._update_plot, names="value", type="change") + for w in ( + self.same_axis_button, + self.plot_templates_button, + self.template_shading_button, + self.hide_axis_button, + self.scalebar, + ): w.observe(self._update_plot, names="value", type="change") if backend_kwargs["display"]: display(self.widget) - def _get_template_shadings(self, sorting_analyzer, unit_ids, templates_percentile_shading): + def _get_template_shadings(self, unit_ids, templates_percentile_shading): templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") if templates_percentile_shading is None: @@ -460,30 +539,40 @@ def _update_plot(self, change): hide_axis = self.hide_axis_button.value do_shading = self.template_shading_button.value - wf_ext = self.sorting_analyzer.get_extension("waveforms") - templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") + if self.sorting_analyzer is not None: + templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") + templates_shadings = self._get_template_shadings(unit_ids, data_plot["templates_percentile_shading"]) + channel_locations = self.sorting_analyzer.get_channel_locations() + + else: + unit_indices = [list(self.templates.unit_ids).index(unit_id) for unit_id in unit_ids] + templates = self.templates.templates_array[unit_indices] + templates_shadings = None + channel_locations = self.templates.get_channel_locations() # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids data_plot["templates"] = templates - templates_shadings = self._get_template_shadings( - self.sorting_analyzer, unit_ids, data_plot["templates_percentile_shading"] - ) data_plot["templates_shading"] = templates_shadings data_plot["same_axis"] = same_axis data_plot["plot_templates"] = plot_templates data_plot["do_shading"] = do_shading data_plot["scale"] = self.scaler.value + data_plot["widen_narrow_scale"] = self.widen_narrow.value + + if same_axis: + self.scalebar.value = False + data_plot["scalebar"] = self.scalebar.value + if data_plot["plot_waveforms"]: + wf_ext = self.sorting_analyzer.get_extension("waveforms") data_plot["wfs_by_ids"] = { unit_id: wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) for unit_id in unit_ids } # TODO option for plot_legend - backend_kwargs = {} - if same_axis: backend_kwargs["ax"] = self.fig_wf.add_subplot() data_plot["set_title"] = False @@ -502,7 +591,6 @@ def _update_plot(self, change): ax.axis("off") # update probe plot - channel_locations = self.sorting_analyzer.get_channel_locations() self.ax_probe.plot( channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 ) @@ -529,7 +617,7 @@ def _update_plot(self, change): fig_probe.canvas.flush_events() -def get_waveforms_scales(sorting_analyzer, templates, channel_locations, x_offset_units=False): +def get_waveforms_scales(templates, channel_locations, nbefore, x_offset_units=False, widen_narrow_scale=1.0): """ Return scales and x_vector for templates plotting """ @@ -555,10 +643,9 @@ def get_waveforms_scales(sorting_analyzer, templates, channel_locations, x_offse y_offset = channel_locations[:, 1][None, :] - nbefore = sorting_analyzer.get_extension("templates").nbefore nsamples = templates.shape[1] - xvect = delta_x * (np.arange(nsamples) - nbefore) / nsamples * 0.7 + xvect = (delta_x * widen_narrow_scale) * (np.arange(nsamples) - nbefore) / nsamples * 0.7 if x_offset_units: ch_locs = channel_locations diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 2e7ec883e6..6ef1a7a782 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -57,7 +57,7 @@ def __init__( unit_ids = sorting_analyzer.unit_ids if unit_colors is None: - unit_colors = get_unit_colors(sorting_analyzer.sorting) + unit_colors = get_unit_colors(sorting_analyzer) if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 29e6474ee9..3461fb179a 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -95,12 +95,19 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB return dict_colors -def get_unit_colors(sorting, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None): +def get_unit_colors( + sorting_or_analyzer_or_templates, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None +): """ Return a dict colors per units. """ colors = get_some_colors( - sorting.unit_ids, color_engine=color_engine, map_name=map_name, format=format, shuffle=shuffle, seed=seed + sorting_or_analyzer_or_templates.unit_ids, + color_engine=color_engine, + map_name=map_name, + format=format, + shuffle=shuffle, + seed=seed, ) return colors From f073d5a559e28e490cfc95ce51f34a81244351dd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 May 2024 19:22:15 +0200 Subject: [PATCH 02/23] Add tests --- src/spikeinterface/widgets/tests/test_widgets.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 59ba29ef73..fd31bd31dc 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -1,4 +1,5 @@ import unittest +from numba.cuda import Out import pytest import os from pathlib import Path @@ -286,6 +287,16 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) + # test with templates + templates_ext = self.sorting_analyzer_dense.get_extension("templates") + templates = templates_ext.get_data(outputs="Templates") + sw.plot_unit_templates( + templates, + sparsity=self.sparsity_strict, + unit_ids=unit_ids, + backend=backend, + **self.backend_kwargs[backend], + ) else: # sortingview doesn't support more than 2 shadings with self.assertRaises(AssertionError): From 80b09e54254c2a7dbb176ca16a9d4a27a9d0ba3e Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Tue, 21 May 2024 15:58:14 +0200 Subject: [PATCH 03/23] Update src/spikeinterface/widgets/tests/test_widgets.py --- src/spikeinterface/widgets/tests/test_widgets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index fd31bd31dc..5366fb864f 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -1,5 +1,4 @@ import unittest -from numba.cuda import Out import pytest import os from pathlib import Path From f309c79dcc7ac7ea069f83482f0acfb28d7af83d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 21 May 2024 16:05:31 +0200 Subject: [PATCH 04/23] Add WidenNarrowWidget for templates --- .../widgets/utils_ipywidgets.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 6d91140500..0d81a8184f 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -329,6 +329,51 @@ def value_changed(self, change=None): self.update_label() +class WidenNarrowWidget(W.VBox): + value = traitlets.Float() + + def __init__(self, value=1.0, factor=1.2, **kwargs): + assert factor > 1.0 + self.factor = factor + + self.scale_label = W.Label("Widen/Narrow", layout=W.Layout(width="95%", justify_content="center")) + + self.right_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Increase horizontal scale", + icon="arrow-right", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width="60%", align_self="center"), + ) + + self.left_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Decrease horizontal scale", + icon="arrow-left", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width="60%", align_self="center"), + ) + + self.right_selector.on_click(self.left_clicked) + self.left_selector.on_click(self.right_clicked) + + self.value = value + super(W.VBox, self).__init__( + children=[self.scale_label, W.HBox([self.left_selector, self.right_selector])], + **kwargs, + ) + + def left_clicked(self, change=None): + self.value = self.value / self.factor + + def right_clicked(self, change=None): + self.value = self.value * self.factor + + class UnitSelector(W.VBox): value = traitlets.List() From 345f5edd9491532f634385aee0e0d3117cb220c6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 21 May 2024 16:08:03 +0200 Subject: [PATCH 05/23] Fix bug with analyzer --- src/spikeinterface/widgets/unit_waveforms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 2b3dc7ed34..f6e16abaae 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -541,9 +541,10 @@ def _update_plot(self, change): if self.sorting_analyzer is not None: templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") - templates_shadings = self._get_template_shadings(unit_ids, data_plot["templates_percentile_shading"]) + templates_shadings = self._get_template_shadings( + unit_ids, self.next_data_plot["templates_percentile_shading"] + ) channel_locations = self.sorting_analyzer.get_channel_locations() - else: unit_indices = [list(self.templates.unit_ids).index(unit_id) for unit_id in unit_ids] templates = self.templates.templates_array[unit_indices] From 614a84cfb955668b87a53d9b5d9642e2a662ccdb Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 21 May 2024 17:40:36 +0200 Subject: [PATCH 06/23] Fix the new way of handling cmap in matpltolib. This fix the matplotlib 3.9 problem related to this. --- .../benchmark/benchmark_motion_estimation.py | 2 +- .../benchmark/benchmark_peak_selection.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 6 +++--- .../sortingcomponents/clustering/sliding_hdbscan.py | 2 +- src/spikeinterface/sortingcomponents/clustering/split.py | 2 +- src/spikeinterface/widgets/collision.py | 4 ++-- src/spikeinterface/widgets/motion.py | 2 +- src/spikeinterface/widgets/multicomparison.py | 8 ++++---- src/spikeinterface/widgets/utils.py | 5 +++-- 9 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 30175288a3..3c5623f202 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -728,7 +728,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, fig # n = self.motion.shape[1] # step = int(np.ceil(max(1, n / show_only))) -# colors = plt.cm.get_cmap("jet", n) +# colors = plt.colormaps["jet"].resampled(n) # for i in range(0, n, step): # ax = axs[0] # ax.plot(self.temporal_bins, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i)) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index d3875ca33d..008de2d931 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -382,7 +382,7 @@ def create_benchmark(self, key): # import matplotlib -# my_cmap = plt.get_cmap(cmap) +# my_cmap = plt.colormaps[cmap] # cNorm = matplotlib.colors.Normalize(vmin=clim[0], vmax=clim[1]) # scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index d3a00c4e6e..083e0077f6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -62,7 +62,7 @@ def _split_waveforms( local_feature_plot = local_feature unique_lab = np.unique(local_labels_with_noise) - cmap = plt.get_cmap("jet", unique_lab.size) + cmap = plt.colormaps["jet"].resampled(unique_lab.size) cmap = {k: cmap(l) for l, k in enumerate(unique_lab)} cmap[-1] = "k" active_ind = np.arange(local_feature.shape[0]) @@ -145,7 +145,7 @@ def _split_waveforms_nested( local_feature_plot = reducer.fit_transform(local_feature) unique_lab = np.unique(active_labels_with_noise) - cmap = plt.get_cmap("jet", unique_lab.size) + cmap = plt.colormaps["jet"].resampled(unique_lab.size) cmap = {k: cmap(l) for l, k in enumerate(unique_lab)} cmap[-1] = "k" cmap[-2] = "b" @@ -276,7 +276,7 @@ def auto_split_clustering( fig, ax = plt.subplots() plot_labels_set = np.unique(local_labels_with_noise) - cmap = plt.get_cmap("jet", plot_labels_set.size) + cmap = plt.colormaps["jet"].resampled(plot_labels_set.size) cmap = {k: cmap(l) for l, k in enumerate(plot_labels_set)} cmap[-1] = "k" cmap[-2] = "b" diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 7e7a8de1d7..2ae22ce07d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -349,7 +349,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): wfs_no_noise = wfs[: -noise.shape[0]] fig, axs = plt.subplots(ncols=3) - cmap = plt.get_cmap("jet", np.unique(local_labels).size) + cmap = plt.colormaps["jet"].resampled(np.unique(local_labels).size) cmap = {label: cmap(l) for l, label in enumerate(local_labels_set)} cmap[-1] = "k" for label in local_labels_set: diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index ceeaeb6633..45f2f44753 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -254,7 +254,7 @@ def split( import matplotlib.pyplot as plt labels_set = np.setdiff1d(possible_labels, [-1]) - colors = plt.get_cmap("tab10", len(labels_set)) + colors = plt.colormaps["tab10"].resampled(len(labels_set)) colors = {k: colors(i) for i, k in enumerate(labels_set)} colors[-1] = "k" fix, axs = plt.subplots(nrows=2) diff --git a/src/spikeinterface/widgets/collision.py b/src/spikeinterface/widgets/collision.py index a5b5891110..34f65a2f89 100644 --- a/src/spikeinterface/widgets/collision.py +++ b/src/spikeinterface/widgets/collision.py @@ -136,7 +136,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1.set_xlabel("lag (ms)") elif dp.mode == "lines": - my_cmap = plt.get_cmap(dp.cmap) + my_cmap = plt.colormaps[dp.cmap] cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) @@ -245,7 +245,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): study = dp.study - my_cmap = plt.get_cmap(dp.cmap) + my_cmap = plt.colormaps[dp.cmap] cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) study.precompute_scores_by_similarities( diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 2e4efc82b0..9d64c89e46 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -128,7 +128,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.scatter_decimate is not None: amps = amps[:: dp.scatter_decimate] amps_abs = amps_abs[:: dp.scatter_decimate] - cmap = plt.get_cmap(dp.amplitude_cmap) + cmap = plt.colormaps[dp.amplitude_cmap] if dp.amplitude_clim is None: amps = amps_abs amps /= q_95 diff --git a/src/spikeinterface/widgets/multicomparison.py b/src/spikeinterface/widgets/multicomparison.py index 78693aacc2..2d4a22a2b3 100644 --- a/src/spikeinterface/widgets/multicomparison.py +++ b/src/spikeinterface/widgets/multicomparison.py @@ -87,7 +87,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): nodelist=sorted(g.nodes), edge_color=edge_col, alpha=dp.alpha_edges, - edge_cmap=plt.cm.get_cmap(dp.edge_cmap), + edge_cmap=plt.colormaps[dp.edge_cmap], edge_vmin=mcmp.match_score, edge_vmax=1, ax=self.ax, @@ -106,7 +106,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1) - cmap = plt.cm.get_cmap(dp.edge_cmap) + cmap = plt.colormaps[dp.edge_cmap] m = plt.cm.ScalarMappable(norm=norm, cmap=cmap) self.figure.colorbar(m) @@ -159,7 +159,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) mcmp = dp.multi_comparison - cmap = plt.get_cmap(dp.cmap) + cmap = plt.colormaps[dp.cmap] colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) sg_names, sg_units = mcmp.compute_subgraphs() # fraction of units with agreement > threshold @@ -242,7 +242,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): backend_kwargs["ncols"] = len(name_list) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - cmap = plt.get_cmap(dp.cmap) + cmap = plt.colormaps[dp.cmap] colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) sg_names, sg_units = mcmp.compute_subgraphs() # fraction of units with agreement > threshold diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 29e6474ee9..9536941c07 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -76,7 +76,8 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB elif color_engine == "matplotlib": # some map have black or white at border so +10 margin = max(4, int(N * 0.08)) - cmap = plt.get_cmap(map_name, N + 2 * margin) + cmap = plt.colormaps[map_name].resampled(N + 2 * margin) + colors = [cmap(i + margin) for i, key in enumerate(keys)] elif color_engine == "colorsys": @@ -153,7 +154,7 @@ def array_to_image( num_channels = data.shape[1] spacing = int(num_channels * spatial_zoom[1] * row_spacing) - cmap = plt.get_cmap(colormap) + cmap = plt.colormaps[colormap] zoomed_data = zoom(data, spatial_zoom) num_timepoints_after_scaling, num_channels_after_scaling = zoomed_data.shape num_timepoints_per_row_after_scaling = int(np.min([num_timepoints_per_row, num_timepoints]) * spatial_zoom[0]) From cc48a409fcb65bbb9477116ff6a1a42be0395309 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 21 May 2024 17:41:56 +0200 Subject: [PATCH 07/23] Try to remove the mpl boundary to run tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 89ea05e5bf..bc04a1bcd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ full = [ "scikit-learn", "networkx", "distinctipy", - "matplotlib<3.9", # See https://github.com/SpikeInterface/spikeinterface/issues/2863 + "matplotlib", "cuda-python; platform_system != 'Darwin'", "numba", ] From 4b67b2099b6a1a1682b33a656a25bd9fdc2c13a2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 21 May 2024 20:52:31 +0200 Subject: [PATCH 08/23] mpl 3.6 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bc04a1bcd5..d040a4a36b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ full = [ "scikit-learn", "networkx", "distinctipy", - "matplotlib", + "matplotlib>=3.6", # matplotlib.colormaps "cuda-python; platform_system != 'Darwin'", "numba", ] From c9b1a9c6a4a06a934778d3369c55bf0956f7100b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 22 May 2024 08:55:40 -0400 Subject: [PATCH 09/23] add docs WIP --- .../sorters/internal/simplesorter.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index 7004514ec7..402c34d543 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -46,6 +46,23 @@ class SimpleSorter(ComponentsBasedSorter): "job_kwargs": {"n_jobs": -1, "chunk_duration": "1s"}, } + _params_descriptions = { + "apply_preprocessing": "whether to apply the preprocessing steps, default: False", + "waveforms": "A dictonary containing waveforms params: ms_before (peak of spike) default: 1.0, ms_after (peak of spike) deafult: 1.5", + "filtering": "A dictionary containing bandpass filter conditions, freq_min' default: 300 and 'freq_max' default:8000.0", + "detection": ( + "A dictionary for specifying the detection conditions of 'peak_sign' (pos or neg) default: 'neg', " + "'detect_threshold' (snr) default: 5.0, 'exclude_sweep_ms' default: 1.5, 'radius_um' default: 150.0" + ), + "features": "A dictionary for the PCA specifying the 'n_components, default: 3", + "clustering": ( + "A dictionary for specifying the clustering parameters: 'method' (to cluster) default: 'hdbscan', " + "'min_cluster_size' (min number of spikes per cluster) default: 25, 'allow_single_cluster' default: True, " + " 'core_dist_n_jobs' (parallelization) default: -1, cluster_selection_method (for hdbscan) default: leaf" + ), + "job_kwargs": "Spikeinterface job_kwargs (see job_kwargs documentation) default 'n_jobs': -1, 'chunk_duration': '1s'", + } + @classmethod def get_sorter_version(cls): return "1.0" From 08360fc4c7367a4f3868250695393fb126128992 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 22 May 2024 08:58:38 -0400 Subject: [PATCH 10/23] oops --- src/spikeinterface/sorters/internal/simplesorter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index 402c34d543..7ead1b1626 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -46,7 +46,7 @@ class SimpleSorter(ComponentsBasedSorter): "job_kwargs": {"n_jobs": -1, "chunk_duration": "1s"}, } - _params_descriptions = { + _params_description = { "apply_preprocessing": "whether to apply the preprocessing steps, default: False", "waveforms": "A dictonary containing waveforms params: ms_before (peak of spike) default: 1.0, ms_after (peak of spike) deafult: 1.5", "filtering": "A dictionary containing bandpass filter conditions, freq_min' default: 300 and 'freq_max' default:8000.0", From 42deaa12844cc383b36a96307155829d9b9f0123 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 22 May 2024 21:48:33 +0200 Subject: [PATCH 11/23] Fix in matpltolib. --- src/spikeinterface/widgets/utils_matplotlib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/utils_matplotlib.py b/src/spikeinterface/widgets/utils_matplotlib.py index 9eb5f275b2..825245750f 100644 --- a/src/spikeinterface/widgets/utils_matplotlib.py +++ b/src/spikeinterface/widgets/utils_matplotlib.py @@ -15,7 +15,7 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, if "ipympl" not in matplotlib.get_backend(): ax = figure.add_subplot(111) else: - ax = figure.add_subplot(111, layout="constrained") + ax = figure.add_subplot(111) axes = np.array([[ax]]) else: assert ncols is not None From 7b48523d00f7039b8b278f8801a82254ef384027 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 22 May 2024 21:48:54 +0200 Subject: [PATCH 12/23] Bug in Analyzer when using zarr. --- src/spikeinterface/core/analyzer_extension_core.py | 6 +++--- src/spikeinterface/core/sortinganalyzer.py | 12 ++++-------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 83e035e84d..066194725d 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -488,9 +488,9 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save self.params["operators"] += [(operator, percentile)] templates_array = self.data[key] - if save: - if not self.sorting_analyzer.is_read_only(): - self.save() + if save: + if not self.sorting_analyzer.is_read_only(): + self.save() if unit_ids is not None: unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 7a77dc28c8..5d1e586dea 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -585,9 +585,8 @@ def load_from_zarr(cls, folder, recording=None): rec_attributes["probegroup"] = None # sparsity - if "sparsity_mask" in zarr_root.attrs: - # sparsity = zarr_root.attrs["sparsity"] - sparsity = ChannelSparsity(zarr_root["sparsity_mask"], cls.unit_ids, rec_attributes["channel_ids"]) + if "sparsity_mask" in zarr_root: + sparsity = ChannelSparsity(np.array(zarr_root["sparsity_mask"]), sorting.unit_ids, rec_attributes["channel_ids"]) else: sparsity = None @@ -1596,10 +1595,6 @@ def load_data(self): self.data[ext_data_name] = ext_data elif self.format == "zarr": - # Alessio - # TODO: we need decide if we make a copy to memory or keep the lazy loading. For binary_folder it used to be lazy with memmap - # but this make the garbage complicated when a data is hold by a plot but the o SortingAnalyzer is delete - # lets talk extension_group = self._get_zarr_extension_group(mode="r") for ext_data_name in extension_group.keys(): ext_data_ = extension_group[ext_data_name] @@ -1615,7 +1610,8 @@ def load_data(self): elif "object" in ext_data_.attrs: ext_data = ext_data_[0] else: - ext_data = ext_data_ + # this load in memmory + ext_data = np.array(ext_data_) self.data[ext_data_name] = ext_data def copy(self, new_sorting_analyzer, unit_ids=None): From 9f7156248a2a3518f8de0d410240b1397ce4ce22 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 May 2024 19:50:12 +0000 Subject: [PATCH 13/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sortinganalyzer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 5d1e586dea..d9fcf44442 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -586,7 +586,9 @@ def load_from_zarr(cls, folder, recording=None): # sparsity if "sparsity_mask" in zarr_root: - sparsity = ChannelSparsity(np.array(zarr_root["sparsity_mask"]), sorting.unit_ids, rec_attributes["channel_ids"]) + sparsity = ChannelSparsity( + np.array(zarr_root["sparsity_mask"]), sorting.unit_ids, rec_attributes["channel_ids"] + ) else: sparsity = None From 586fab652dffeaf3b33ed249c5dc518162227254 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 23 May 2024 09:59:14 +0200 Subject: [PATCH 14/23] improve analyzer tests --- src/spikeinterface/core/sparsity.py | 8 ++++++++ src/spikeinterface/core/tests/test_sortinganalyzer.py | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 415ca42548..cefd7bd950 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -118,6 +118,14 @@ def __repr__(self): txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}" return txt + def __eq__(self, other): + return ( + isinstance(other, ChannelSparsity) + and np.array_equal(self.channel_ids, other.channel_ids) + and np.array_equal(self.unit_ids, other.unit_ids) + and np.array_equal(self.mask, other.mask) + ) + @property def unit_id_to_channel_ids(self): if self._unit_id_to_channel_ids is None: diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index faed5161c6..13e01c32da 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -155,10 +155,15 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): data = sorting_analyzer2.get_extension("dummy").data assert "result_one" in data + assert isinstance(data["result_one"], str) + assert isinstance(data["result_two"], np.ndarray) assert data["result_two"].size == original_sorting.to_spike_vector().size + assert np.array_equal(data["result_two"], sorting_analyzer.get_extension("dummy").data["result_two"]) assert sorting_analyzer2.return_scaled == sorting_analyzer.return_scaled + assert sorting_analyzer2.sparsity == sorting_analyzer.sparsity + # select unit_ids to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory": From 5d9b1e0c94699d2e3c33d7e1aa5a351c850c3105 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 23 May 2024 10:43:08 +0200 Subject: [PATCH 15/23] Update src/spikeinterface/sorters/internal/simplesorter.py --- src/spikeinterface/sorters/internal/simplesorter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index 7ead1b1626..314c552d6d 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -48,8 +48,8 @@ class SimpleSorter(ComponentsBasedSorter): _params_description = { "apply_preprocessing": "whether to apply the preprocessing steps, default: False", - "waveforms": "A dictonary containing waveforms params: ms_before (peak of spike) default: 1.0, ms_after (peak of spike) deafult: 1.5", - "filtering": "A dictionary containing bandpass filter conditions, freq_min' default: 300 and 'freq_max' default:8000.0", + "waveforms": "A dictonary containing waveforms params: 'ms_before' (peak of spike) default: 1.0, 'ms_after' (peak of spike) deafult: 1.5", + "filtering": "A dictionary containing bandpass filter conditions, 'freq_min' default: 300 and 'freq_max' default:8000.0", "detection": ( "A dictionary for specifying the detection conditions of 'peak_sign' (pos or neg) default: 'neg', " "'detect_threshold' (snr) default: 5.0, 'exclude_sweep_ms' default: 1.5, 'radius_um' default: 150.0" From 79d883f1401a2731bc2bf5e9e6e60f55f35377b4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 23 May 2024 12:23:53 +0200 Subject: [PATCH 16/23] remove verbose from job_kwargs --- src/spikeinterface/core/job_tools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index fa79a8ce01..9cf22563d7 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -47,7 +47,6 @@ "chunk_duration", "progress_bar", "mp_context", - "verbose", "max_threads_per_process", ) From 4ff43c1c0012669650165b0dab2873ac83edac28 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 23 May 2024 12:40:35 +0200 Subject: [PATCH 17/23] Propagate remove verbose from job_kwargs to write_binary_recording() --- src/spikeinterface/core/recording_tools.py | 5 ++++- .../core/tests/test_recording_tools.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index e698302ee1..5974e69e46 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -73,6 +73,7 @@ def write_binary_recording( add_file_extension: bool = True, byte_offset: int = 0, auto_cast_uint: bool = True, + verbose : bool = True, **job_kwargs, ): """ @@ -98,6 +99,8 @@ def write_binary_recording( auto_cast_uint: bool, default: True If True, unsigned integers are automatically cast to int if the specified dtype is signed .. deprecated:: 0.103, use the `unsigned_to_signed` function instead. + verbose: bool + If True, output is verbose {} """ job_kwargs = fix_job_kwargs(job_kwargs) @@ -138,7 +141,7 @@ def write_binary_recording( init_func = _init_binary_worker init_args = (recording, file_path_dict, dtype, byte_offset, cast_unsigned) executor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name="write_binary_recording", **job_kwargs + recording, func, init_func, init_args, job_name="write_binary_recording", verbose=verbose, **job_kwargs ) executor.run() diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 5e0b77a151..7a92846df8 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -37,8 +37,8 @@ def test_write_binary_recording(tmp_path): file_paths = [tmp_path / "binary01.raw"] # Write binary recording - job_kwargs = dict(verbose=False, n_jobs=1) - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) + job_kwargs = dict(n_jobs=1) + write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( @@ -64,9 +64,9 @@ def test_write_binary_recording_offset(tmp_path): file_paths = [tmp_path / "binary01.raw"] # Write binary recording - job_kwargs = dict(verbose=False, n_jobs=1) + job_kwargs = dict(n_jobs=1) byte_offset = 125 - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, **job_kwargs) + write_binary_recording(recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, verbose=False, **job_kwargs) # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( @@ -97,8 +97,8 @@ def test_write_binary_recording_parallel(tmp_path): file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] # Write binary recording - job_kwargs = dict(verbose=False, n_jobs=2, chunk_memory="100k", mp_context="spawn") - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) + job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn") + write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( @@ -127,8 +127,8 @@ def test_write_binary_recording_multiple_segment(tmp_path): file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] # Write binary recording - job_kwargs = dict(verbose=False, n_jobs=2, chunk_memory="100k", mp_context="spawn") - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) + job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn") + write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( From cd4b115d08ddf576a60917c04647912bb9177e54 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 23 May 2024 14:48:10 +0200 Subject: [PATCH 18/23] Still remove some verbose --- src/spikeinterface/sortingcomponents/clustering/circus.py | 3 ++- .../sortingcomponents/clustering/position_and_features.py | 2 +- .../sortingcomponents/clustering/position_and_pca.py | 2 +- .../sortingcomponents/clustering/position_ptp_scaled.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 3 ++- .../sortingcomponents/clustering/sliding_hdbscan.py | 2 +- .../sortingcomponents/tests/test_clustering.py | 5 +++-- .../sortingcomponents/tests/test_peak_detection.py | 2 +- .../sortingcomponents/tests/test_peak_selection.py | 3 +-- 9 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 06d2f1f6db..ac9625cfd5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -63,6 +63,7 @@ class CircusClustering: "noise_levels": None, "tmp_folder": None, "job_kwargs": {}, + "verbose": True, } @classmethod @@ -72,7 +73,7 @@ def main_function(cls, recording, peaks, params): job_kwargs = fix_job_kwargs(params["job_kwargs"]) d = params - verbose = job_kwargs.get("verbose", True) + verbose = d["verbose"] fs = recording.get_sampling_frequency() ms_before = params["ms_before"] diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index a07a6140e1..d23eb26239 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -42,7 +42,7 @@ class PositionAndFeaturesClustering: "ms_before": 1.5, "ms_after": 1.5, "cleaning_method": "dip", - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "verbose": True, "progress_bar": True}, + "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, } @classmethod diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py index 0b1b8cc742..4dfe3c960c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py @@ -38,7 +38,7 @@ class PositionAndPCAClustering: "ms_after": 2.5, "n_components_by_channel": 3, "n_components": 5, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "verbose": True, "progress_bar": True}, + "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, "hdbscan_global_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "hdbscan_local_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "waveform_mode": "shared_memory", diff --git a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py index 2195362543..788addf1e6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py @@ -26,7 +26,7 @@ class PositionPTPScaledClustering: "ptps": None, "scales": (1, 1, 10), "peak_localization_kwargs": {"method": "center_of_mass"}, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "verbose": True, "progress_bar": True}, + "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, "hdbscan_kwargs": { "min_cluster_size": 20, "min_samples": 20, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 6c1ad75383..42573962a5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -56,6 +56,7 @@ class RandomProjectionClustering: "smoothing_kwargs": {"window_length_ms": 0.25}, "tmp_folder": None, "job_kwargs": {}, + "verbose": True, } @classmethod @@ -65,7 +66,7 @@ def main_function(cls, recording, peaks, params): job_kwargs = fix_job_kwargs(params["job_kwargs"]) d = params - verbose = job_kwargs.get("verbose", True) + verbose = d["verbose"] fs = recording.get_sampling_frequency() radius_um = params["radius_um"] diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 7e7a8de1d7..7528c696d4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -55,7 +55,7 @@ class SlidingHdbscanClustering: "auto_merge_quantile_limit": 0.8, "ratio_num_channel_intersect": 0.5, # ~ 'auto_trash_misalignment_shift' : 4, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "verbose": True, "progress_bar": True}, + "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, } @classmethod diff --git a/src/spikeinterface/sortingcomponents/tests/test_clustering.py b/src/spikeinterface/sortingcomponents/tests/test_clustering.py index 427d120c5d..76a9d8a85e 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_clustering.py +++ b/src/spikeinterface/sortingcomponents/tests/test_clustering.py @@ -13,7 +13,7 @@ def job_kwargs(): - return dict(n_jobs=1, chunk_size=10000, progress_bar=True, verbose=True, mp_context="spawn") + return dict(n_jobs=1, chunk_size=10000, progress_bar=True, mp_context="spawn") @pytest.fixture(name="job_kwargs", scope="module") @@ -78,6 +78,7 @@ def test_find_cluster_from_peaks(clustering_method, recording, peaks, peak_locat peak_locations = run_peak_locations(recording, peaks, job_kwargs) # method = "position_and_pca" # method = "circus" - method = "tdc_clustering" + # method = "tdc_clustering" + method = "random_projections" test_find_cluster_from_peaks(method, recording, peaks, peak_locations) diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index d36c59dc69..2ecccb421c 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -60,7 +60,7 @@ def sorting(dataset): def job_kwargs(): - return dict(n_jobs=1, chunk_size=10000, progress_bar=True, verbose=True, mp_context="spawn") + return dict(n_jobs=1, chunk_size=10000, progress_bar=True, mp_context="spawn") @pytest.fixture(name="job_kwargs", scope="module") diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_selection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_selection.py index 4326f21512..d133a0f9d2 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_selection.py @@ -23,13 +23,12 @@ def test_select_peaks(): detect_threshold=5, exclude_sweep_ms=0.1, chunk_size=10000, - verbose=1, progress_bar=False, noise_levels=noise_levels, ) peak_locations = localize_peaks( - recording, peaks, method="center_of_mass", n_jobs=2, chunk_size=10000, verbose=True, progress_bar=True + recording, peaks, method="center_of_mass", n_jobs=2, chunk_size=10000, progress_bar=True ) n_peaks = 100 From f70633427836d461eb7388d067f9b3215346068d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 23 May 2024 16:02:22 +0200 Subject: [PATCH 19/23] still remove verbose in job_kwargs --- .../sortingcomponents/tests/test_motion_estimation.py | 8 ++++---- .../sortingcomponents/tests/test_peak_localization.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index 36d2d34f4d..7ed23b2d32 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -47,7 +47,6 @@ def setup_module(): detect_threshold=5, exclude_sweep_ms=0.1, chunk_size=10000, - verbose=1, progress_bar=True, pipeline_nodes=pipeline_nodes, ) @@ -156,12 +155,13 @@ def test_estimate_motion(): bin_um=10.0, margin_um=5, output_extra_check=True, - progress_bar=False, - verbose=False, + ) kwargs.update(cases_kwargs) + + job_kwargs = dict(progress_bar=False) - motion, temporal_bins, spatial_bins, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs) + motion, temporal_bins, spatial_bins, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs, **job_kwargs) motions[name] = motion diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_localization.py b/src/spikeinterface/sortingcomponents/tests/test_peak_localization.py index 33d45af6c4..a10a81ec80 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_localization.py @@ -10,8 +10,8 @@ def test_localize_peaks(): recording, _ = make_dataset() - # job_kwargs = dict(n_jobs=2, chunk_size=10000, verbose=False, progress_bar=True) - job_kwargs = dict(n_jobs=1, chunk_size=10000, verbose=False, progress_bar=True) + # job_kwargs = dict(n_jobs=2, chunk_size=10000, progress_bar=True) + job_kwargs = dict(n_jobs=1, chunk_size=10000, progress_bar=True) peaks = detect_peaks( recording, method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, **job_kwargs From c0d3d084b93a10813e24cbff58847c2700e5e48e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 23 May 2024 16:32:33 +0200 Subject: [PATCH 20/23] still remove verbose at wrong place --- .../sortingcomponents/clustering/random_projections.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 42573962a5..77d47aec16 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -162,7 +162,6 @@ def main_function(cls, recording, peaks, params): cleaning_matching_params[value] = None cleaning_matching_params["chunk_duration"] = "100ms" cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["verbose"] = False cleaning_matching_params["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() From 8d39956fdb9b3b2e2e0b2bd65f32a368f13bed76 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 23 May 2024 16:34:34 +0200 Subject: [PATCH 21/23] oups --- src/spikeinterface/sortingcomponents/clustering/circus.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index ac9625cfd5..c9aaee1329 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -251,7 +251,6 @@ def main_function(cls, recording, peaks, params): cleaning_matching_params.pop(value) cleaning_matching_params["chunk_duration"] = "100ms" cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["verbose"] = False cleaning_matching_params["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() From 63b7e7c3b31456cbb84c34b8f501f10d4a8c415c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 May 2024 14:40:41 +0000 Subject: [PATCH 22/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/recording_tools.py | 2 +- src/spikeinterface/core/tests/test_recording_tools.py | 4 +++- .../sortingcomponents/tests/test_motion_estimation.py | 7 ++++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 5974e69e46..8f9e67c954 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -73,7 +73,7 @@ def write_binary_recording( add_file_extension: bool = True, byte_offset: int = 0, auto_cast_uint: bool = True, - verbose : bool = True, + verbose: bool = True, **job_kwargs, ): """ diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 7a92846df8..d83e4d76fc 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -66,7 +66,9 @@ def test_write_binary_recording_offset(tmp_path): # Write binary recording job_kwargs = dict(n_jobs=1) byte_offset = 125 - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, verbose=False, **job_kwargs) + write_binary_recording( + recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, verbose=False, **job_kwargs + ) # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index 7ed23b2d32..36f623ebf8 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -155,13 +155,14 @@ def test_estimate_motion(): bin_um=10.0, margin_um=5, output_extra_check=True, - ) kwargs.update(cases_kwargs) - + job_kwargs = dict(progress_bar=False) - motion, temporal_bins, spatial_bins, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs, **job_kwargs) + motion, temporal_bins, spatial_bins, extra_check = estimate_motion( + recording, peaks, peak_locations, **kwargs, **job_kwargs + ) motions[name] = motion From 7fee4033bd5b3c10001f59a5bef41b49f89b8ef0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 23 May 2024 16:46:47 +0200 Subject: [PATCH 23/23] More verbose to remove --- src/spikeinterface/sorters/internal/simplesorter.py | 2 +- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index 314c552d6d..199352ab73 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -71,7 +71,7 @@ def get_sorter_version(cls): def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs = params["job_kwargs"] job_kwargs = fix_job_kwargs(job_kwargs) - job_kwargs.update({"verbose": verbose, "progress_bar": verbose}) + job_kwargs.update({"progress_bar": verbose}) from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 2af28fb179..c1021e787a 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -113,7 +113,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs = params["job_kwargs"] job_kwargs = fix_job_kwargs(job_kwargs) - job_kwargs.update({"verbose": verbose, "progress_bar": verbose}) + job_kwargs.update({"progress_bar": verbose}) recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)