From 12fd197859a3bb91099e9f5fb73fc5f74f923847 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 19 Sep 2023 12:56:55 +0200 Subject: [PATCH 01/26] Use sparsity mask and handle right border correctly --- .../postprocessing/amplitude_scalings.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..4dab68fdf8 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -90,10 +90,7 @@ def _run(self, **job_kwargs): if self._params["max_dense_channels"] is not None: assert recording.get_num_channels() <= self._params["max_dense_channels"], "" sparsity = ChannelSparsity.create_dense(we) - sparsity_inds = sparsity.unit_id_to_channel_indices - - # easier to use in chunk function as spikes use unit_index instead o id - unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)} + sparsity_mask = sparsity.mask all_templates = we.get_all_templates() # precompute segment slice @@ -113,7 +110,7 @@ def _run(self, **job_kwargs): self.spikes, all_templates, segment_slices, - unit_inds_to_channel_indices, + sparsity_mask, nbefore, nafter, cut_out_before, @@ -262,7 +259,7 @@ def _init_worker_amplitude_scalings( spikes, all_templates, segment_slices, - unit_inds_to_channel_indices, + sparsity_mask, nbefore, nafter, cut_out_before, @@ -282,7 +279,7 @@ def _init_worker_amplitude_scalings( worker_ctx["cut_out_before"] = cut_out_before worker_ctx["cut_out_after"] = cut_out_after worker_ctx["return_scaled"] = return_scaled - worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices + worker_ctx["sparsity_mask"] = sparsity_mask worker_ctx["handle_collisions"] = handle_collisions worker_ctx["delta_collision_samples"] = delta_collision_samples @@ -306,7 +303,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) recording = worker_ctx["recording"] all_templates = worker_ctx["all_templates"] segment_slices = worker_ctx["segment_slices"] - unit_inds_to_channel_indices = worker_ctx["unit_inds_to_channel_indices"] + sparsity_mask = worker_ctx["sparsity_mask"] nbefore = worker_ctx["nbefore"] cut_out_before = worker_ctx["cut_out_before"] cut_out_after = worker_ctx["cut_out_after"] @@ -339,7 +336,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( - local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices + local_spikes, local_spikes_w_margin, delta_collision_samples, sparsity_mask ) else: collisions_local = {} @@ -354,7 +351,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] - sparse_indices = unit_inds_to_channel_indices[unit_index] + sparse_indices = sparsity_mask[unit_index] template = all_templates[unit_index][:, sparse_indices] template = template[nbefore - cut_out_before : nbefore + cut_out_after] sample_centered = sample_index - start_frame @@ -393,7 +390,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) right, nbefore, all_templates, - unit_inds_to_channel_indices, + sparsity_mask, cut_out_before, cut_out_after, ) @@ -410,14 +407,14 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) ### Collision handling ### -def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): +def _are_unit_indices_overlapping(sparsity_mask, i, j): """ Returns True if the unit indices i and j are overlapping, False otherwise Parameters ---------- - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices + sparsity_mask: boolean mask + The sparsity mask i: int The first unit index j: int @@ -428,13 +425,13 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): bool True if the unit indices i and j are overlapping, False otherwise """ - if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0: + if np.sum(np.logical_and(sparsity_mask[i], sparsity_mask[j])) > 0: return True else: return False -def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices): +def find_collisions(spikes, spikes_w_margin, delta_collision_samples, sparsity_mask): """ Finds the collisions between spikes. @@ -446,8 +443,8 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ An array of spikes within the added margin delta_collision_samples: int The maximum number of samples between two spikes to consider them as overlapping - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices + sparsity_mask: boolean mask + The sparsity mask Returns ------- @@ -480,7 +477,7 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ # find the overlapping spikes in space as well for possible_overlapping_spike_index in possible_overlapping_spike_indices: if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, + sparsity_mask, spike["unit_index"], spikes_w_margin[possible_overlapping_spike_index]["unit_index"], ): @@ -501,7 +498,7 @@ def fit_collision( right, nbefore, all_templates, - unit_inds_to_channel_indices, + sparsity_mask, cut_out_before, cut_out_after, ): @@ -528,8 +525,8 @@ def fit_collision( The number of samples before the spike to consider for the fit. all_templates: np.ndarray A numpy array of shape (n_units, n_samples, n_channels) containing the templates. - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices. + sparsity_mask: boolean mask + The sparsity mask cut_out_before: int The number of samples to cut out before the spike. cut_out_after: int @@ -547,14 +544,15 @@ def fit_collision( sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity - sparse_indices = np.array([], dtype="int") + sparse_indices = np.zeros(sparsity_mask.shape[1], dtype="int") for spike in collision: - sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] - sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + sparse_indices_i = sparsity_mask[spike["unit_index"]] + sparse_indices = np.logical_or(sparse_indices, sparse_indices_i) local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] + num_samples_local_waveform = local_waveform.shape[0] y = local_waveform.T.flatten() X = np.zeros((len(y), len(collision))) @@ -567,8 +565,10 @@ def fit_collision( # deal with borders if sample_centered - cut_out_before < 0: full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :] - elif sample_centered + cut_out_after > end_frame + right: - full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)] + elif sample_centered + cut_out_after > num_samples_local_waveform: + full_template[sample_centered - cut_out_before :] = template_cut[ + : -(cut_out_after + sample_centered - num_samples_local_waveform) + ] else: full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten() From e964731b33401db1757ce813d2078c00a36dcf34 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Sep 2023 16:36:31 +0200 Subject: [PATCH 02/26] Start refactor ipywidgets plot_traces --- src/spikeinterface/widgets/traces.py | 29 +- .../widgets/utils_ipywidgets.py | 251 ++++++++++++++++-- 2 files changed, 254 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 7bb2126744..c6e36387f8 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -276,11 +276,16 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display + import ipywidgets.widgets as W from .utils_ipywidgets import ( check_ipywidget_backend, make_timeseries_controller, make_channel_controller, make_scale_controller, + + TimeSlider, + ScaleWidget, + ) check_ipywidget_backend() @@ -308,6 +313,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): t_start = 0.0 t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() + + ts_widget, ts_controller = make_timeseries_controller( t_start, t_stop, @@ -319,6 +326,22 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm, ) + # some widgets + self.time_slider = TimeSlider( + durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], + sampling_frequency=rec0.sampling_frequency, + ) + self.layer_selector = W.Dropdown(description="layer", options=data_plot["layer_keys"], + layout=W.Layout(width="5cm"),) + self.mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=data_plot["mode"], + layout=W.Layout(width="5cm"),) + self.scaler = ScaleWidget() + left_sidebar = W.VBox( + children=[self.layer_selector, self.mode_selector, self.scaler], + layout=W.Layout(width="5cm"), + ) + + ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) @@ -346,8 +369,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.widget = widgets.AppLayout( center=self.figure.canvas, - footer=ts_widget, - left_sidebar=scale_widget, + # footer=ts_widget, + footer=self.time_slider, + # left_sidebar=scale_widget, + left_sidebar = left_sidebar, right_sidebar=ch_widget, pane_heights=[0, 6, 1], pane_widths=ratios, diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index a7c571d1f0..674a2d2cc7 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -1,4 +1,6 @@ -import ipywidgets.widgets as widgets +import ipywidgets.widgets as W +import traitlets + import numpy as np @@ -10,20 +12,20 @@ def check_ipywidget_backend(): def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): - time_slider = widgets.FloatSlider( + time_slider = W.FloatSlider( orientation="horizontal", description="time:", value=time_range[0], min=t_start, max=t_stop, continuous_update=False, - layout=widgets.Layout(width=f"{width_cm}cm"), + layout=W.Layout(width=f"{width_cm}cm"), ) - layer_selector = widgets.Dropdown(description="layer", options=layer_keys) - segment_selector = widgets.Dropdown(description="segment", options=list(range(num_segments))) - window_sizer = widgets.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") - mode_selector = widgets.Dropdown(options=["line", "map"], description="mode", value=mode) - all_layers = widgets.Checkbox(description="plot all layers", value=all_layers) + layer_selector = W.Dropdown(description="layer", options=layer_keys) + segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) + window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") + mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) + all_layers = W.Checkbox(description="plot all layers", value=all_layers) controller = { "layer_key": layer_selector, @@ -33,32 +35,32 @@ def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_r "mode": mode_selector, "all_layers": all_layers, } - widget = widgets.VBox( - [time_slider, widgets.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] + widget = W.VBox( + [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] ) return widget, controller def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): - unit_label = widgets.Label(value="units:") + unit_label = W.Label(value="units:") - unit_selector = widgets.SelectMultiple( + unit_selector = W.SelectMultiple( options=all_unit_ids, value=list(unit_ids), disabled=False, - layout=widgets.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), ) controller = {"unit_ids": unit_selector} - widget = widgets.VBox([unit_label, unit_selector]) + widget = W.VBox([unit_label, unit_selector]) return widget, controller def make_channel_controller(recording, width_cm, height_cm): - channel_label = widgets.Label("channel indices:", layout=widgets.Layout(justify_content="center")) - channel_selector = widgets.IntRangeSlider( + channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) + channel_selector = W.IntRangeSlider( value=[0, recording.get_num_channels()], min=0, max=recording.get_num_channels(), @@ -68,37 +70,238 @@ def make_channel_controller(recording, width_cm, height_cm): orientation="vertical", readout=True, readout_format="d", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), ) controller = {"channel_inds": channel_selector} - widget = widgets.VBox([channel_label, channel_selector]) + widget = W.VBox([channel_label, channel_selector]) return widget, controller def make_scale_controller(width_cm, height_cm): - scale_label = widgets.Label("Scale", layout=widgets.Layout(justify_content="center")) + scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) - plus_selector = widgets.Button( + plus_selector = W.Button( description="", disabled=False, button_style="", # 'success', 'info', 'warning', 'danger' or '' tooltip="Increase scale", icon="arrow-up", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), ) - minus_selector = widgets.Button( + minus_selector = W.Button( description="", disabled=False, button_style="", # 'success', 'info', 'warning', 'danger' or '' tooltip="Decrease scale", icon="arrow-down", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), ) controller = {"plus": plus_selector, "minus": minus_selector} - widget = widgets.VBox([scale_label, plus_selector, minus_selector]) + widget = W.VBox([scale_label, plus_selector, minus_selector]) return widget, controller + + + +class TimeSlider(W.HBox): + + position = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) + + def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): + + + self.num_segments = len(durations) + self.frame_limits = [int(sampling_frequency * d) for d in durations] + self.sampling_frequency = sampling_frequency + start_frame = int(time_range[0] * sampling_frequency) + end_frame = int(time_range[1] * sampling_frequency) + + self.frame_range = (start_frame, end_frame) + + self.segment_index = 0 + self.position = (start_frame, end_frame, self.segment_index) + + + layout = W.Layout(align_items="center", width="1.5cm", height="100%") + but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) + but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) + + but_left.on_click(self.move_left) + but_right.on_click(self.move_right) + + self.move_size = W.Dropdown(options=['10 ms', '100 ms', '1 s', '10 s', '1 m', '30 m', '1 h',], # '6 h', '24 h' + value='1 s', + description='', + layout = W.Layout(width="2cm") + ) + + # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) + self.time_label = W.Text(value=f'{time_range[0]}',description='', + disabled=False, layout=W.Layout(width='5.5cm')) + self.time_label.observe(self.time_label_changed, names='value', type="change") + + + self.slider = W.IntSlider( + orientation='horizontal', + # description='time:', + value=start_frame, + min=0, + max=self.frame_limits[self.segment_index], + readout=False, + continuous_update=False, + layout=W.Layout(width=f'70%') + ) + + self.slider.observe(self.slider_moved, names='value', type="change") + + delta_s = np.diff(self.frame_range) / sampling_frequency + + self.window_sizer = W.BoundedFloatText(value=delta_s, step=1, + min=0.01, max=30., + description='win (s)', + layout=W.Layout(width='auto') + # layout=W.Layout(width=f'10%') + ) + self.window_sizer.observe(self.win_size_changed, names='value', type="change") + + self.segment_selector = W.Dropdown(description="segment", options=list(range(self.num_segments))) + self.segment_selector.observe(self.segment_changed, names='value', type="change") + + super(W.HBox, self).__init__(children=[self.segment_selector, but_left, self.move_size, but_right, + self.slider, self.time_label, self.window_sizer], + layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs) + + self.observe(self.position_changed, names=['position'], type="change") + + def position_changed(self, change=None): + + self.unobserve(self.position_changed, names=['position'], type="change") + + start, stop, seg_index = self.position + if seg_index < 0 or seg_index >= self.num_segments: + self.position = change['old'] + return + if start < 0 or stop < 0: + self.position = change['old'] + return + if start >= self.frame_limits[seg_index] or start > self.frame_limits[seg_index]: + self.position = change['old'] + return + + self.segment_selector.value = seg_index + self.update_time(new_frame=start, update_slider=True, update_label=True) + delta_s = (stop - start) / self.sampling_frequency + self.window_sizer.value = delta_s + + self.observe(self.position_changed, names=['position'], type="change") + + def update_time(self, new_frame=None, new_time=None, update_slider=False, update_label=False): + if new_frame is None and new_time is None: + start_frame = self.slider.value + elif new_frame is None: + start_frame = int(new_time * self.sampling_frequency) + else: + start_frame = new_frame + delta_s = self.window_sizer.value + end_frame = start_frame + int(delta_s * self.sampling_frequency) + + # clip + start_frame = max(0, start_frame) + end_frame = min(self.frame_limits[self.segment_index], end_frame) + + + start_time = start_frame / self.sampling_frequency + + if update_label: + self.time_label.unobserve(self.time_label_changed, names='value', type="change") + self.time_label.value = f'{start_time}' + self.time_label.observe(self.time_label_changed, names='value', type="change") + + if update_slider: + self.slider.unobserve(self.slider_moved, names='value', type="change") + self.slider.value = start_frame + self.slider.observe(self.slider_moved, names='value', type="change") + + self.frame_range = (start_frame, end_frame) + + def time_label_changed(self, change=None): + try: + new_time = float(self.time_label.value) + except: + new_time = None + if new_time is not None: + self.update_time(new_time=new_time, update_slider=True) + + + def win_size_changed(self, change=None): + self.update_time() + + def slider_moved(self, change=None): + new_frame = self.slider.value + self.update_time(new_frame=new_frame, update_label=True) + + def move(self, sign): + value, units = self.move_size.value.split(' ') + value = int(value) + delta_s = (sign * np.timedelta64(value, units)) / np.timedelta64(1, 's') + delta_sample = int(delta_s * self.sampling_frequency) + + new_frame = self.frame_range[0] + delta_sample + self.slider.value = new_frame + + def move_left(self, change=None): + self.move(-1) + + def move_right(self, change=None): + self.move(+1) + + def segment_changed(self, change=None): + self.segment_index = self.segment_selector.value + + self.slider.unobserve(self.slider_moved, names='value', type="change") + # self.slider.value = 0 + self.slider.max = self.frame_limits[self.segment_index] + self.slider.observe(self.slider_moved, names='value', type="change") + + self.update_time(new_frame=0, update_slider=True, update_label=True) + + + +class ScaleWidget(W.VBox): + def __init__(self, **kwargs): + scale_label = W.Label("Scale", + layout=W.Layout(layout=W.Layout(width='95%'), + justify_content="center")) + + self.plus_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Increase scale", + icon="arrow-up", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width='95%'), + ) + + self.minus_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Decrease scale", + icon="arrow-down", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width='95%'), + ) + + # controller = {"plus": plus_selector, "minus": minus_selector} + # widget = W.VBox([scale_label, plus_selector, minus_selector]) + + + super(W.VBox, self).__init__(children=[scale_label, self.plus_selector, self.minus_selector], + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs) From 389737efe1330f1f75afb73caedb41bb6bf84b4d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Sep 2023 20:58:38 +0200 Subject: [PATCH 03/26] wip refactor plot traces ipywidget --- src/spikeinterface/widgets/traces.py | 126 ++++++++++++++---- .../widgets/utils_ipywidgets.py | 62 ++++++--- 2 files changed, 145 insertions(+), 43 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index c6e36387f8..efd32ffb24 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -279,9 +279,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import ipywidgets.widgets as W from .utils_ipywidgets import ( check_ipywidget_backend, - make_timeseries_controller, + # make_timeseries_controller, make_channel_controller, - make_scale_controller, + # make_scale_controller, TimeSlider, ScaleWidget, @@ -315,21 +315,22 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): - ts_widget, ts_controller = make_timeseries_controller( - t_start, - t_stop, - data_plot["layer_keys"], - rec0.get_num_segments(), - data_plot["time_range"], - data_plot["mode"], - False, - width_cm, - ) + # ts_widget, ts_controller = make_timeseries_controller( + # t_start, + # t_stop, + # data_plot["layer_keys"], + # rec0.get_num_segments(), + # data_plot["time_range"], + # data_plot["mode"], + # False, + # width_cm, + # ) # some widgets self.time_slider = TimeSlider( durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], sampling_frequency=rec0.sampling_frequency, + # layout=W.Layout(height="2cm"), ) self.layer_selector = W.Dropdown(description="layer", options=data_plot["layer_keys"], layout=W.Layout(width="5cm"),) @@ -338,22 +339,22 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.scaler = ScaleWidget() left_sidebar = W.VBox( children=[self.layer_selector, self.mode_selector, self.scaler], - layout=W.Layout(width="5cm"), + layout=W.Layout(width="3.5cm"), ) ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) - scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) + # scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) - self.controller = ts_controller - self.controller.update(ch_controller) - self.controller.update(scale_controller) + # self.controller = ts_controller + # self.controller.update(ch_controller) + # self.controller.update(scale_controller) self.recordings = data_plot["recordings"] self.return_scaled = data_plot["return_scaled"] self.list_traces = None - self.actual_segment_index = self.controller["segment_index"].value + # self.actual_segment_index = self.controller["segment_index"].value self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] self.t_stops = [ @@ -361,11 +362,11 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): for seg_index in range(self.rec0.get_num_segments()) ] - for w in self.controller.values(): - if isinstance(w, widgets.Button): - w.on_click(self._update_ipywidget) - else: - w.observe(self._update_ipywidget) + # for w in self.controller.values(): + # if isinstance(w, widgets.Button): + # w.on_click(self._update_ipywidget) + # else: + # w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=self.figure.canvas, @@ -379,12 +380,89 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - self._update_ipywidget(None) + # self._update_ipywidget(None) + + self._retrieve_traces() + self._update_plot() + + # only layer selector and time change generate a new traces retrieve + self.time_slider.observe(self._retrieve_traces, names='value', type="change") + self.layer_selector.observe(self._retrieve_traces, names='value', type="change") + # other widgets only refresh + self.scaler.observe(self._update_plot, names='value', type="change") + self.mode_selector.observe(self._update_plot, names='value', type="change") + if backend_kwargs["display"]: # self.check_backend() display(self.widget) + + + def _retrieve_traces(self, change=None): + # done when: + # * time or window is changes + # * layer is changed + + # TODO connect with channel selector + channel_ids = self.rec0.channel_ids + + # all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids + # if self.data_plot["order"] is not None: + # all_channel_ids = all_channel_ids[self.data_plot["order"]] + # channel_ids = all_channel_ids[channel_indices] + if self.data_plot["order_channel_by_depth"]: + order, _ = order_channels_by_depth(self.rec0, channel_ids) + else: + order = None + + start_frame, end_frame, segment_index = self.time_slider.value + time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency + + times, list_traces, frame_range, channel_ids = _get_trace_list( + self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled + ) + self.list_traces = list_traces + + self._update_plot() + + def _update_plot(self, change=None): + # done when: + # * time or window is changed (after _retrive_traces) + # * layer is changed (after _retrive_traces) + #  * scale is change + # * mode is change + + data_plot = self.next_data_plot + + # matplotlib next_data_plot dict update at each call + data_plot["mode"] = self.mode_selector.value + # data_plot["frame_range"] = frame_range + # data_plot["time_range"] = time_range + data_plot["with_colorbar"] = False + # data_plot["recordings"] = recordings + # data_plot["layer_keys"] = layer_keys + # data_plot["list_traces"] = list_traces_plot + # data_plot["times"] = times + # data_plot["clims"] = clims + # data_plot["channel_ids"] = channel_ids + + list_traces = [traces * self.scaler.value for traces in self.list_traces] + data_plot["list_traces"] = list_traces + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + self.ax.clear() + self.plot_matplotlib(data_plot, **backend_kwargs) + + fig = self.ax.figure + fig.canvas.draw() + fig.canvas.flush_events() + + + + def _update_ipywidget(self, change): import ipywidgets.widgets as widgets diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 674a2d2cc7..ad0ead7bc0 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -109,7 +109,7 @@ def make_scale_controller(width_cm, height_cm): class TimeSlider(W.HBox): - position = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) + value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): @@ -123,10 +123,10 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): self.frame_range = (start_frame, end_frame) self.segment_index = 0 - self.position = (start_frame, end_frame, self.segment_index) + self.value = (start_frame, end_frame, self.segment_index) - layout = W.Layout(align_items="center", width="1.5cm", height="100%") + layout = W.Layout(align_items="center", width="2cm", hight="1.5cm") but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) @@ -176,21 +176,21 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): layout=W.Layout(align_items="center", width="100%", height="100%"), **kwargs) - self.observe(self.position_changed, names=['position'], type="change") + self.observe(self.value_changed, names=['value'], type="change") - def position_changed(self, change=None): + def value_changed(self, change=None): - self.unobserve(self.position_changed, names=['position'], type="change") + self.unobserve(self.value_changed, names=['value'], type="change") - start, stop, seg_index = self.position + start, stop, seg_index = self.value if seg_index < 0 or seg_index >= self.num_segments: - self.position = change['old'] + self.value = change['old'] return if start < 0 or stop < 0: - self.position = change['old'] + self.value = change['old'] return if start >= self.frame_limits[seg_index] or start > self.frame_limits[seg_index]: - self.position = change['old'] + self.value = change['old'] return self.segment_selector.value = seg_index @@ -198,7 +198,7 @@ def position_changed(self, change=None): delta_s = (stop - start) / self.sampling_frequency self.window_sizer.value = delta_s - self.observe(self.position_changed, names=['position'], type="change") + self.observe(self.value_changed, names=['value'], type="change") def update_time(self, new_frame=None, new_time=None, update_slider=False, update_label=False): if new_frame is None and new_time is None: @@ -228,6 +228,7 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update self.slider.observe(self.slider_moved, names='value', type="change") self.frame_range = (start_frame, end_frame) + self.value = (start_frame, end_frame, self.segment_index) def time_label_changed(self, change=None): try: @@ -273,8 +274,14 @@ def segment_changed(self, change=None): class ScaleWidget(W.VBox): - def __init__(self, **kwargs): - scale_label = W.Label("Scale", + value = traitlets.Float() + + def __init__(self, value=1., factor=1.2, **kwargs): + + assert factor > 1. + self.factor = factor + + self.scale_label = W.Label("Scale", layout=W.Layout(layout=W.Layout(width='95%'), justify_content="center")) @@ -285,7 +292,7 @@ def __init__(self, **kwargs): tooltip="Increase scale", icon="arrow-up", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='95%'), + layout=W.Layout(width='60%', align_self='center'), ) self.minus_selector = W.Button( @@ -295,13 +302,30 @@ def __init__(self, **kwargs): tooltip="Decrease scale", icon="arrow-down", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='95%'), + layout=W.Layout(width='60%', align_self='center'), ) - # controller = {"plus": plus_selector, "minus": minus_selector} - # widget = W.VBox([scale_label, plus_selector, minus_selector]) + self.plus_selector.on_click(self.plus_clicked) + self.minus_selector.on_click(self.minus_clicked) - - super(W.VBox, self).__init__(children=[scale_label, self.plus_selector, self.minus_selector], + self.value = 1. + super(W.VBox, self).__init__(children=[self.plus_selector, self.scale_label, self.minus_selector], # layout=W.Layout(align_items="center", width="100%", height="100%"), **kwargs) + + self.update_label() + self.observe(self.value_changed, names=['value'], type="change") + + def update_label(self): + self.scale_label.value = f"Scale: {self.value:0.2f}" + + + def plus_clicked(self, change=None): + self.value = self.value * self.factor + + def minus_clicked(self, change=None): + self.value = self.value / self.factor + + + def value_changed(self, change=None): + self.update_label() From e5995f2aa6445fd878e1c0881f11299f8ae22a2d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Sep 2023 22:59:53 +0200 Subject: [PATCH 04/26] ipywidget backend refactor wip --- src/spikeinterface/widgets/traces.py | 298 +++++------------- .../widgets/utils_ipywidgets.py | 175 ++++++---- 2 files changed, 190 insertions(+), 283 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index efd32ffb24..d107c5cb23 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -280,23 +280,23 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): from .utils_ipywidgets import ( check_ipywidget_backend, # make_timeseries_controller, - make_channel_controller, + # make_channel_controller, # make_scale_controller, - TimeSlider, + ChannelSelector, ScaleWidget, - ) check_ipywidget_backend() self.next_data_plot = data_plot.copy() - self.next_data_plot["add_legend"] = False + - recordings = data_plot["recordings"] + self.recordings = data_plot["recordings"] # first layer - rec0 = recordings[data_plot["layer_keys"][0]] + # rec0 = recordings[data_plot["layer_keys"][0]] + rec0 = self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] cm = 1 / 2.54 @@ -310,107 +310,92 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.figure, self.ax = plt.subplots(figsize=(0.9 * ratios[1] * width_cm * cm, height_cm * cm)) plt.show() - t_start = 0.0 - t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() - - - - # ts_widget, ts_controller = make_timeseries_controller( - # t_start, - # t_stop, - # data_plot["layer_keys"], - # rec0.get_num_segments(), - # data_plot["time_range"], - # data_plot["mode"], - # False, - # width_cm, - # ) - # some widgets self.time_slider = TimeSlider( durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], sampling_frequency=rec0.sampling_frequency, # layout=W.Layout(height="2cm"), ) - self.layer_selector = W.Dropdown(description="layer", options=data_plot["layer_keys"], - layout=W.Layout(width="5cm"),) - self.mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=data_plot["mode"], - layout=W.Layout(width="5cm"),) + + start_frame = int(data_plot["time_range"][0] * rec0.sampling_frequency) + end_frame = int(data_plot["time_range"][1] * rec0.sampling_frequency) + + self.time_slider.value = start_frame, end_frame, data_plot["segment_index"] + + _layer_keys = data_plot["layer_keys"] + if len(_layer_keys) > 1: + _layer_keys = ['ALL'] + _layer_keys + self.layer_selector = W.Dropdown(options=_layer_keys, + layout=W.Layout(width="95%"), + ) + self.mode_selector = W.Dropdown(options=["line", "map"], value=data_plot["mode"], + # layout=W.Layout(width="5cm"), + layout=W.Layout(width="95%"), + ) self.scaler = ScaleWidget() + self.channel_selector = ChannelSelector(self.rec0.channel_ids) + left_sidebar = W.VBox( - children=[self.layer_selector, self.mode_selector, self.scaler], + children=[ + W.Label(value="layer"), + self.layer_selector, + W.Label(value="mode"), + self.mode_selector, + self.scaler, + # self.channel_selector, + ], layout=W.Layout(width="3.5cm"), + align_items='center', ) - - ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) - - # scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) - - # self.controller = ts_controller - # self.controller.update(ch_controller) - # self.controller.update(scale_controller) - - self.recordings = data_plot["recordings"] self.return_scaled = data_plot["return_scaled"] - self.list_traces = None - # self.actual_segment_index = self.controller["segment_index"].value - - self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] - self.t_stops = [ - self.rec0.get_num_samples(segment_index=seg_index) / self.rec0.get_sampling_frequency() - for seg_index in range(self.rec0.get_num_segments()) - ] - - # for w in self.controller.values(): - # if isinstance(w, widgets.Button): - # w.on_click(self._update_ipywidget) - # else: - # w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=self.figure.canvas, - # footer=ts_widget, footer=self.time_slider, - # left_sidebar=scale_widget, left_sidebar = left_sidebar, - right_sidebar=ch_widget, + right_sidebar=self.channel_selector, pane_heights=[0, 6, 1], pane_widths=ratios, ) # a first update - # self._update_ipywidget(None) - self._retrieve_traces() self._update_plot() - # only layer selector and time change generate a new traces retrieve + # callbacks: + # some widgets generate a full retrieve + refresh self.time_slider.observe(self._retrieve_traces, names='value', type="change") self.layer_selector.observe(self._retrieve_traces, names='value', type="change") + self.channel_selector.observe(self._retrieve_traces, names='value', type="change") # other widgets only refresh self.scaler.observe(self._update_plot, names='value', type="change") - self.mode_selector.observe(self._update_plot, names='value', type="change") + # map is a special case because needs to check layer also + self.mode_selector.observe(self._mode_changed, names='value', type="change") - if backend_kwargs["display"]: # self.check_backend() display(self.widget) - + def _get_layers(self): + layer = self.layer_selector.value + if layer == 'ALL': + layer_keys = self.data_plot["layer_keys"] + else: + layer_keys = [layer] + if self.mode_selector.value == "map": + layer_keys = layer_keys[:1] + return layer_keys + + def _mode_changed(self, change=None): + if self.mode_selector.value == "map" and self.layer_selector.value == "ALL": + self.layer_selector.value = self.data_plot["layer_keys"][0] + else: + self._update_plot() def _retrieve_traces(self, change=None): - # done when: - # * time or window is changes - # * layer is changed + channel_ids = np.array(self.channel_selector.value) - # TODO connect with channel selector - channel_ids = self.rec0.channel_ids - - # all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids - # if self.data_plot["order"] is not None: - # all_channel_ids = all_channel_ids[self.data_plot["order"]] - # channel_ids = all_channel_ids[channel_indices] if self.data_plot["order_channel_by_depth"]: order, _ = order_channels_by_depth(self.rec0, channel_ids) else: @@ -419,176 +404,61 @@ def _retrieve_traces(self, change=None): start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency + self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times, list_traces, frame_range, channel_ids = _get_trace_list( - self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled + self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled ) - self.list_traces = list_traces + + self._channel_ids = channel_ids + self._list_traces = list_traces + self._times = times + self._time_range = time_range + self._frame_range = (start_frame, end_frame) + self._segment_index = segment_index self._update_plot() def _update_plot(self, change=None): - # done when: - # * time or window is changed (after _retrive_traces) - # * layer is changed (after _retrive_traces) - #  * scale is change - # * mode is change - data_plot = self.next_data_plot # matplotlib next_data_plot dict update at each call - data_plot["mode"] = self.mode_selector.value - # data_plot["frame_range"] = frame_range - # data_plot["time_range"] = time_range - data_plot["with_colorbar"] = False - # data_plot["recordings"] = recordings - # data_plot["layer_keys"] = layer_keys - # data_plot["list_traces"] = list_traces_plot - # data_plot["times"] = times - # data_plot["clims"] = clims - # data_plot["channel_ids"] = channel_ids - - list_traces = [traces * self.scaler.value for traces in self.list_traces] - data_plot["list_traces"] = list_traces - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.ax.clear() - self.plot_matplotlib(data_plot, **backend_kwargs) - - fig = self.ax.figure - fig.canvas.draw() - fig.canvas.flush_events() - - - - - def _update_ipywidget(self, change): - import ipywidgets.widgets as widgets - - # if changing the layer_key, no need to retrieve and process traces - retrieve_traces = True - scale_up = False - scale_down = False - if change is not None: - for cname, c in self.controller.items(): - if isinstance(change, dict): - if change["owner"] is c and cname == "layer_key": - retrieve_traces = False - elif isinstance(change, widgets.Button): - if change is c and cname == "plus": - scale_up = True - if change is c and cname == "minus": - scale_down = True - - t_start = self.controller["t_start"].value - window = self.controller["window"].value - layer_key = self.controller["layer_key"].value - segment_index = self.controller["segment_index"].value - mode = self.controller["mode"].value - chan_start, chan_stop = self.controller["channel_inds"].value - - if mode == "line": - self.controller["all_layers"].layout.visibility = "visible" - all_layers = self.controller["all_layers"].value - elif mode == "map": - self.controller["all_layers"].layout.visibility = "hidden" - all_layers = False - - if all_layers: - self.controller["layer_key"].layout.visibility = "hidden" - else: - self.controller["layer_key"].layout.visibility = "visible" - - if chan_start == chan_stop: - chan_stop += 1 - channel_indices = np.arange(chan_start, chan_stop) - - t_stop = self.t_stops[segment_index] - if self.actual_segment_index != segment_index: - # change time_slider limits - self.controller["t_start"].max = t_stop - self.actual_segment_index = segment_index - - # protect limits - if t_start >= t_stop - window: - t_start = t_stop - window - - time_range = np.array([t_start, t_start + window]) - data_plot = self.next_data_plot + mode = self.mode_selector.value + layer_keys = self._get_layers() - if retrieve_traces: - all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids - if self.data_plot["order"] is not None: - all_channel_ids = all_channel_ids[self.data_plot["order"]] - channel_ids = all_channel_ids[channel_indices] - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None - times, list_traces, frame_range, channel_ids = _get_trace_list( - self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled - ) - self.list_traces = list_traces - else: - times = data_plot["times"] - list_traces = data_plot["list_traces"] - frame_range = data_plot["frame_range"] - channel_ids = data_plot["channel_ids"] - - if all_layers: - layer_keys = self.data_plot["layer_keys"] - recordings = self.recordings - list_traces_plot = self.list_traces - else: - layer_keys = [layer_key] - recordings = {layer_key: self.recordings[layer_key]} - list_traces_plot = [self.list_traces[list(self.recordings.keys()).index(layer_key)]] - - if scale_up: - if mode == "line": - data_plot["vspacing"] *= 0.8 - elif mode == "map": - data_plot["clims"] = { - layer: (1.2 * val[0], 1.2 * val[1]) for layer, val in self.data_plot["clims"].items() - } - if scale_down: - if mode == "line": - data_plot["vspacing"] *= 1.2 - elif mode == "map": - data_plot["clims"] = { - layer: (0.8 * val[0], 0.8 * val[1]) for layer, val in self.data_plot["clims"].items() - } - - self.next_data_plot["vspacing"] = data_plot["vspacing"] - self.next_data_plot["clims"] = data_plot["clims"] + data_plot["mode"] = mode + data_plot["frame_range"] = self._frame_range + data_plot["time_range"] = self._time_range + data_plot["with_colorbar"] = False + data_plot["recordings"] = self._selected_recordings + data_plot["add_legend"] = False if mode == "line": clims = None elif mode == "map": - clims = {layer_key: self.data_plot["clims"][layer_key]} + clims = {k: self.data_plot["clims"][k] for k in layer_keys} - # matplotlib next_data_plot dict update at each call - data_plot["mode"] = mode - data_plot["frame_range"] = frame_range - data_plot["time_range"] = time_range - data_plot["with_colorbar"] = False - data_plot["recordings"] = recordings - data_plot["layer_keys"] = layer_keys - data_plot["list_traces"] = list_traces_plot - data_plot["times"] = times data_plot["clims"] = clims - data_plot["channel_ids"] = channel_ids + data_plot["channel_ids"] = self._channel_ids + + data_plot["layer_keys"] = layer_keys + data_plot["colors"] = {k:self.data_plot["colors"][k] for k in layer_keys} + + list_traces = [traces * self.scaler.value for traces in self._list_traces] + data_plot["list_traces"] = list_traces + data_plot["times"] = self._times backend_kwargs = {} backend_kwargs["ax"] = self.ax + self.ax.clear() self.plot_matplotlib(data_plot, **backend_kwargs) + self.ax.set_title("") fig = self.ax.figure fig.canvas.draw() fig.canvas.flush_events() + def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import handle_display_and_url diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index ad0ead7bc0..ab2b51a7bb 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -11,35 +11,35 @@ def check_ipywidget_backend(): assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" -def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): - time_slider = W.FloatSlider( - orientation="horizontal", - description="time:", - value=time_range[0], - min=t_start, - max=t_stop, - continuous_update=False, - layout=W.Layout(width=f"{width_cm}cm"), - ) - layer_selector = W.Dropdown(description="layer", options=layer_keys) - segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) - window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") - mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) - all_layers = W.Checkbox(description="plot all layers", value=all_layers) - - controller = { - "layer_key": layer_selector, - "segment_index": segment_selector, - "window": window_sizer, - "t_start": time_slider, - "mode": mode_selector, - "all_layers": all_layers, - } - widget = W.VBox( - [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] - ) - - return widget, controller +# def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): +# time_slider = W.FloatSlider( +# orientation="horizontal", +# description="time:", +# value=time_range[0], +# min=t_start, +# max=t_stop, +# continuous_update=False, +# layout=W.Layout(width=f"{width_cm}cm"), +# ) +# layer_selector = W.Dropdown(description="layer", options=layer_keys) +# segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) +# window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") +# mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) +# all_layers = W.Checkbox(description="plot all layers", value=all_layers) + +# controller = { +# "layer_key": layer_selector, +# "segment_index": segment_selector, +# "window": window_sizer, +# "t_start": time_slider, +# "mode": mode_selector, +# "all_layers": all_layers, +# } +# widget = W.VBox( +# [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] +# ) + +# return widget, controller def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): @@ -58,52 +58,52 @@ def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): return widget, controller -def make_channel_controller(recording, width_cm, height_cm): - channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) - channel_selector = W.IntRangeSlider( - value=[0, recording.get_num_channels()], - min=0, - max=recording.get_num_channels(), - step=1, - disabled=False, - continuous_update=False, - orientation="vertical", - readout=True, - readout_format="d", - layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), - ) +# def make_channel_controller(recording, width_cm, height_cm): +# channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) +# channel_selector = W.IntRangeSlider( +# value=[0, recording.get_num_channels()], +# min=0, +# max=recording.get_num_channels(), +# step=1, +# disabled=False, +# continuous_update=False, +# orientation="vertical", +# readout=True, +# readout_format="d", +# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), +# ) - controller = {"channel_inds": channel_selector} - widget = W.VBox([channel_label, channel_selector]) +# controller = {"channel_inds": channel_selector} +# widget = W.VBox([channel_label, channel_selector]) - return widget, controller +# return widget, controller -def make_scale_controller(width_cm, height_cm): - scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) +# def make_scale_controller(width_cm, height_cm): +# scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) - plus_selector = W.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Increase scale", - icon="arrow-up", - layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) +# plus_selector = W.Button( +# description="", +# disabled=False, +# button_style="", # 'success', 'info', 'warning', 'danger' or '' +# tooltip="Increase scale", +# icon="arrow-up", +# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), +# ) - minus_selector = W.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Decrease scale", - icon="arrow-down", - layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) +# minus_selector = W.Button( +# description="", +# disabled=False, +# button_style="", # 'success', 'info', 'warning', 'danger' or '' +# tooltip="Decrease scale", +# icon="arrow-down", +# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), +# ) - controller = {"plus": plus_selector, "minus": minus_selector} - widget = W.VBox([scale_label, plus_selector, minus_selector]) +# controller = {"plus": plus_selector, "minus": minus_selector} +# widget = W.VBox([scale_label, plus_selector, minus_selector]) - return widget, controller +# return widget, controller @@ -126,7 +126,7 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): self.value = (start_frame, end_frame, self.segment_index) - layout = W.Layout(align_items="center", width="2cm", hight="1.5cm") + layout = W.Layout(align_items="center", width="2.5cm", height="1.cm") but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) @@ -141,7 +141,7 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) self.time_label = W.Text(value=f'{time_range[0]}',description='', - disabled=False, layout=W.Layout(width='5.5cm')) + disabled=False, layout=W.Layout(width='2.5cm')) self.time_label.observe(self.time_label_changed, names='value', type="change") @@ -271,6 +271,43 @@ def segment_changed(self, change=None): self.update_time(new_frame=0, update_slider=True, update_label=True) +class ChannelSelector(W.VBox): + value = traitlets.List() + + def __init__(self, channel_ids, **kwargs): + self.channel_ids = list(channel_ids) + self.value = self.channel_ids + + channel_label = W.Label("Channels", layout=W.Layout(justify_content="center")) + n = len(channel_ids) + self.slider = W.IntRangeSlider( + value=[0, n], + min=0, + max=n, + step=1, + disabled=False, + continuous_update=False, + orientation="vertical", + readout=True, + readout_format="d", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(height="100%"), + ) + + + + super(W.VBox, self).__init__(children=[channel_label, self.slider], + layout=W.Layout(align_items="center"), + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs) + self.slider.observe(self.on_slider_changed, names=['value'], type="change") + # self.update_label() + # self.observe(self.value_changed, names=['value'], type="change") + + def on_slider_changed(self, change=None): + i0, i1 = self.slider.value + self.value = self.channel_ids[i0:i1] + class ScaleWidget(W.VBox): From 7b92c2153d4fad412823100fd77079e3cf286138 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 08:06:37 +0200 Subject: [PATCH 05/26] improve channel selector --- .../widgets/utils_ipywidgets.py | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index ab2b51a7bb..705dd09f23 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -294,20 +294,52 @@ def __init__(self, channel_ids, **kwargs): layout=W.Layout(height="100%"), ) + # first channel are bottom: need reverse + self.selector = W.SelectMultiple( + options=self.channel_ids[::-1], + value=self.channel_ids[::-1], + disabled=False, + # layout=W.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(height="100%", width="2cm"), + ) + hbox = W.HBox(children=[self.slider, self.selector]) - - super(W.VBox, self).__init__(children=[channel_label, self.slider], + super(W.VBox, self).__init__(children=[channel_label, hbox], layout=W.Layout(align_items="center"), # layout=W.Layout(align_items="center", width="100%", height="100%"), **kwargs) self.slider.observe(self.on_slider_changed, names=['value'], type="change") - # self.update_label() + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + + # TODO external value change # self.observe(self.value_changed, names=['value'], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value + + self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") + self.selector.value = self.channel_ids[i0:i1][::-1] + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + self.value = self.channel_ids[i0:i1] + def on_selector_changed(self, change=None): + channel_ids = self.selector.value + channel_ids = channel_ids[::-1] + + if len(channel_ids) > 0: + self.slider.unobserve(self.on_slider_changed, names=['value'], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=['value'], type="change") + + self.value = channel_ids + + + + + class ScaleWidget(W.VBox): From c46a7cba4b1e937d40050d0061017256ab5dade3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 10:31:05 +0200 Subject: [PATCH 06/26] Allow to restrict sparsity --- .../postprocessing/amplitude_scalings.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 4dab68fdf8..3eac333781 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -68,7 +68,6 @@ def _run(self, **job_kwargs): delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) return_scaled = we._params["return_scaled"] - unit_ids = we.unit_ids if ms_before is not None: assert ( @@ -82,9 +81,16 @@ def _run(self, **job_kwargs): cut_out_before = int(ms_before / 1000 * we.sampling_frequency) if ms_before is not None else nbefore cut_out_after = int(ms_after / 1000 * we.sampling_frequency) if ms_after is not None else nafter - if we.is_sparse(): + if we.is_sparse() and self._params["sparsity"] is None: sparsity = we.sparsity - elif self._params["sparsity"] is not None: + elif we.is_sparse() and self._params["sparsity"] is not None: + sparsity = self._params["sparsity"] + # assert provided sparsity is sparser than the one in the waveform extractor + waveform_sparsity = we.sparsity + assert np.all( + np.sum(waveform_sparsity.mask, 1) - np.sum(sparsity.mask, 1) > 0 + ), "The provided sparsity needs to be sparser than the one in the waveform extractor!" + elif not we.is_sparse() and self._params["sparsity"] is not None: sparsity = self._params["sparsity"] else: if self._params["max_dense_channels"] is not None: @@ -362,7 +368,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) template = template[cut_out_before - sample_index :] elif sample_index + cut_out_after > end_frame + right: local_waveform = traces_with_margin[cut_out_start:, sparse_indices] - template = template[: -(sample_index + cut_out_after - end_frame)] + template = template[: -(sample_index + cut_out_after - end_frame - right)] else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape From 2e305586d5b39bb8bfa89280057579a97726e93a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 11:09:05 +0200 Subject: [PATCH 07/26] ipywidgets backend start UnitCOntroller --- src/spikeinterface/widgets/amplitudes.py | 69 ++++++++++--------- .../widgets/utils_ipywidgets.py | 39 +++++++++-- 2 files changed, 71 insertions(+), 37 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 7ef6e0ff61..b60de98cb0 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -171,9 +171,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - import ipywidgets.widgets as widgets + # import ipywidgets.widgets as widgets + import ipywidgets.widgets as W from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller, UnitSelector check_ipywidget_backend() @@ -188,60 +189,62 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ratios = [0.15, 0.85] with plt.ioff(): - output = widgets.Output() + output = W.Output() with output: self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.unit_selector = UnitSelector(we.unit_ids) + self.unit_selector.value = list(we.unit_ids)[:1] - plot_histograms = widgets.Checkbox( + self.checkbox_histograms = W.Checkbox( value=data_plot["plot_histograms"], - description="plot histograms", - disabled=False, + description="hist", + # disabled=False, ) - footer = plot_histograms - - self.controller = {"plot_histograms": plot_histograms} - self.controller.update(unit_controller) - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + left_sidebar = W.VBox( + children=[ + self.unit_selector, + self.checkbox_histograms, + ], + layout = W.Layout(align_items="center", width="4cm", height="100%"), + ) - self.widget = widgets.AppLayout( + self.widget = W.AppLayout( center=self.figure.canvas, - left_sidebar=unit_widget, + left_sidebar=left_sidebar, pane_widths=ratios + [0], - footer=footer, ) # a first update - self._update_ipywidget(None) + self._full_update_plot() + + self.unit_selector.observe(self._update_plot, names='value', type="change") + self.checkbox_histograms.observe(self._full_update_plot, names='value', type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _full_update_plot(self, change=None): self.figure.clear() + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + data_plot["plot_histograms"] = self.checkbox_histograms.value + + backend_kwargs = dict(figure=self.figure, axes=None, ax=None) + self.plot_matplotlib(data_plot, **backend_kwargs) + self._update_plot() - unit_ids = self.controller["unit_ids"].value - plot_histograms = self.controller["plot_histograms"].value + def _update_plot(self, change=None): + for ax in self.axes.flatten(): + ax.clear() - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_histograms"] = plot_histograms - - backend_kwargs = {} - # backend_kwargs["figure"] = self.fig - backend_kwargs["figure"] = self.figure - backend_kwargs["axes"] = None - backend_kwargs["ax"] = None + data_plot["unit_ids"] = self.unit_selector.value + data_plot["plot_histograms"] = self.checkbox_histograms.value + backend_kwargs = dict(figure=None, axes=self.axes, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) self.figure.canvas.draw() diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 705dd09f23..d2c41f234a 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -338,10 +338,6 @@ def on_selector_changed(self, change=None): - - - - class ScaleWidget(W.VBox): value = traitlets.Float() @@ -398,3 +394,38 @@ def minus_clicked(self, change=None): def value_changed(self, change=None): self.update_label() + + +class UnitSelector(W.VBox): + value = traitlets.List() + + def __init__(self, unit_ids, **kwargs): + self.unit_ids = list(unit_ids) + self.value = self.unit_ids + + label = W.Label("Units", layout=W.Layout(justify_content="center")) + + self.selector = W.SelectMultiple( + options=self.unit_ids, + value=self.unit_ids, + disabled=False, + layout=W.Layout(height="100%", width="2cm"), + ) + + super(W.VBox, self).__init__(children=[label, self.selector], + layout=W.Layout(align_items="center"), + **kwargs) + + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + + self.observe(self.value_changed, names=['value'], type="change") + + def on_selector_changed(self, change=None): + unit_ids = self.selector.value + self.value = unit_ids + + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") + self.selector.value = change['new'] + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + From 4e31329d9aed376ecc41c4238a2f4836f94054ea Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 11:37:18 +0200 Subject: [PATCH 08/26] Add spikes on border when generating sorting, PCA sparse return fixes --- src/spikeinterface/core/generate.py | 28 +++++++++++++++++ .../core/tests/test_generate.py | 30 +++++++++++++++++-- .../postprocessing/amplitude_scalings.py | 12 ++++---- .../postprocessing/principal_component.py | 15 ++++++++-- .../tests/common_extension_tests.py | 26 ++++++++++++++-- .../tests/test_principal_component.py | 12 ++++---- 6 files changed, 104 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..741dd20000 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -123,6 +123,9 @@ def generate_sorting( firing_rates=3.0, empty_units=None, refractory_period_ms=3.0, # in ms + add_spikes_on_borders=False, + num_spikes_per_border=3, + border_size_samples=20, seed=None, ): """ @@ -142,6 +145,12 @@ def generate_sorting( List of units that will have no spikes. (used for testing mainly). refractory_period_ms : float, default: 3.0 The refractory period in ms + add_spikes_on_borders : bool, default: False + If True, spikes will be added close to the borders of the segments. + num_spikes_per_border : int, default: 3 + The number of spikes to add close to the borders of the segments. + border_size_samples : int, default: 20 + The size of the border in samples to add border spikes. seed : int, default: None The random seed @@ -151,11 +160,13 @@ def generate_sorting( The sorting object """ seed = _ensure_seed(seed) + rng = np.random.default_rng(seed) num_segments = len(durations) unit_ids = np.arange(num_units) spikes = [] for segment_index in range(num_segments): + num_samples = int(sampling_frequency * durations[segment_index]) times, labels = synthesize_random_firings( num_units=num_units, sampling_frequency=sampling_frequency, @@ -175,7 +186,23 @@ def generate_sorting( spikes_in_seg["unit_index"] = labels spikes_in_seg["segment_index"] = segment_index spikes.append(spikes_in_seg) + + if add_spikes_on_borders: + spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype) + spikes_on_borders["segment_index"] = segment_index + spikes_on_borders["unit_index"] = rng.choice(num_units, size=2 * num_spikes_per_border, replace=True) + # at start + spikes_on_borders["sample_index"][:num_spikes_per_border] = rng.integers( + 0, border_size_samples, num_spikes_per_border + ) + # at end + spikes_on_borders["sample_index"][num_spikes_per_border:] = rng.integers( + num_samples - border_size_samples, num_samples, num_spikes_per_border + ) + spikes.append(spikes_on_borders) + spikes = np.concatenate(spikes) + spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))] sorting = NumpySorting(spikes, sampling_frequency, unit_ids) @@ -596,6 +623,7 @@ def __init__( dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") + assert strategy in ("tile_pregenerated", "on_the_fly"), "'strategy' must be 'tile_pregenerated' or 'on_the_fly'" BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 9ba5de42d6..3844e421ac 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -26,15 +26,38 @@ def test_generate_recording(): - # TODO even this is extenssivly tested in all other function + # TODO even this is extensively tested in all other functions pass def test_generate_sorting(): - # TODO even this is extenssivly tested in all other function + # TODO even this is extensively tested in all other functions pass +def test_generate_sorting_with_spikes_on_borders(): + num_spikes_on_borders = 10 + border_size_samples = 10 + segment_duration = 10 + for nseg in [1, 2, 3]: + sorting = generate_sorting( + durations=[segment_duration] * nseg, + sampling_frequency=30000, + num_units=10, + add_spikes_on_borders=True, + num_spikes_per_border=num_spikes_on_borders, + border_size_samples=border_size_samples, + ) + spikes = sorting.to_spike_vector(concatenated=False) + # at least num_border spikes at borders for all segments + for i, spikes_in_segment in enumerate(spikes): + num_samples = int(segment_duration * 30000) + assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders + assert ( + np.sum(spikes_in_segment["sample_index"] >= num_samples - border_size_samples) >= num_spikes_on_borders + ) + + def measure_memory_allocation(measure_in_process: bool = True) -> float: """ A local utility to measure memory allocation at a specific point in time. @@ -399,7 +422,7 @@ def test_generate_ground_truth_recording(): if __name__ == "__main__": strategy = "tile_pregenerated" # strategy = "on_the_fly" - test_noise_generator_memory() + # test_noise_generator_memory() # test_noise_generator_under_giga() # test_noise_generator_correct_shape(strategy) # test_noise_generator_consistency_across_calls(strategy, 0, 5) @@ -410,3 +433,4 @@ def test_generate_ground_truth_recording(): # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() + test_generate_sorting_with_spikes_on_borders() diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 3eac333781..c86337a30d 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -16,6 +16,7 @@ class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension): """ extension_name = "amplitude_scalings" + handle_sparsity = True def __init__(self, waveform_extractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) @@ -357,7 +358,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] - sparse_indices = sparsity_mask[unit_index] + (sparse_indices,) = np.nonzero(sparsity_mask[unit_index]) template = all_templates[unit_index][:, sparse_indices] template = template[nbefore - cut_out_before : nbefore + cut_out_after] sample_centered = sample_index - start_frame @@ -368,7 +369,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) template = template[cut_out_before - sample_index :] elif sample_index + cut_out_after > end_frame + right: local_waveform = traces_with_margin[cut_out_start:, sparse_indices] - template = template[: -(sample_index + cut_out_after - end_frame - right)] + template = template[: -(sample_index + cut_out_after - (end_frame + right))] else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape @@ -550,10 +551,11 @@ def fit_collision( sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity - sparse_indices = np.zeros(sparsity_mask.shape[1], dtype="int") + common_sparse_mask = np.zeros(sparsity_mask.shape[1], dtype="int") for spike in collision: - sparse_indices_i = sparsity_mask[spike["unit_index"]] - sparse_indices = np.logical_or(sparse_indices, sparse_indices_i) + mask_i = sparsity_mask[spike["unit_index"]] + common_sparse_mask = np.logical_or(common_sparse_mask, mask_i) + (sparse_indices,) = np.nonzero(common_sparse_mask) local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 233625e09e..1214b84ac4 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -84,9 +84,16 @@ def get_projections(self, unit_id): Returns ------- proj: np.array - The PCA projections (num_waveforms, num_components, num_channels) + The PCA projections (num_waveforms, num_components, num_channels). + In case sparsity is used, only the projections on sparse channels are returned. """ - return self._extension_data[f"pca_{unit_id}"] + projections = self._extension_data[f"pca_{unit_id}"] + mode = self._params["mode"] + if mode in ("by_channel_local", "by_channel_global"): + sparsity = self.get_sparsity() + if sparsity is not None: + projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] + return projections def get_pca_model(self): """ @@ -211,6 +218,10 @@ def project_new(self, new_waveforms, unit_id=None): wfs_flat = new_waveforms.reshape(new_waveforms.shape[0], -1) projections = pca_model.transform(wfs_flat) + # take care of sparsity (not in case of concatenated) + if mode in ("by_channel_local", "by_channel_global"): + if sparsity is not None: + projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] return projections def get_sparsity(self): diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b9c72f9b99..8657d1dced 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -5,7 +5,7 @@ from pathlib import Path from spikeinterface import extract_waveforms, load_extractor, compute_sparsity -from spikeinterface.extractors import toy_example +from spikeinterface.core.generate import generate_ground_truth_recording if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "postprocessing" @@ -26,7 +26,18 @@ def setUp(self): self.cache_folder = cache_folder # 1-segment - recording, sorting = toy_example(num_segments=1, num_units=10, num_channels=12) + recording, sorting = generate_ground_truth_recording( + durations=[10], + sampling_frequency=30000, + num_channels=12, + num_units=10, + dtype="float32", + seed=91, + generate_sorting_kwargs=dict(add_spikes_on_borders=True), + noise_kwargs=dict(noise_level=10.0, strategy="tile_pregenerated"), + ) + + # add gains and offsets and save gain = 0.1 recording.set_channel_gains(gain) recording.set_channel_offsets(0) @@ -53,7 +64,16 @@ def setUp(self): self.sparsity1 = compute_sparsity(we1, method="radius", radius_um=50) # 2-segments - recording, sorting = toy_example(num_segments=2, num_units=10) + recording, sorting = generate_ground_truth_recording( + durations=[10, 5], + sampling_frequency=30000, + num_channels=12, + num_units=10, + dtype="float32", + seed=91, + generate_sorting_kwargs=dict(add_spikes_on_borders=True), + noise_kwargs=dict(noise_level=10.0, strategy="tile_pregenerated"), + ) recording.set_channel_gains(gain) recording.set_channel_offsets(0) if (cache_folder / "toy_rec_2seg").is_dir(): diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 5d64525b52..04ce42b70e 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -87,13 +87,13 @@ def test_sparse(self): pc.run() for i, unit_id in enumerate(unit_ids): proj = pc.get_projections(unit_id) - assert proj.shape[1:] == (5, 4) + assert proj.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) # test project_new unit_id = 3 new_wfs = we.get_waveforms(unit_id) new_proj = pc.project_new(new_wfs, unit_id=unit_id) - assert new_proj.shape == (new_wfs.shape[0], 5, 4) + assert new_proj.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) if DEBUG: import matplotlib.pyplot as plt @@ -197,8 +197,8 @@ def test_project_new(self): if __name__ == "__main__": test = PrincipalComponentsExtensionTest() test.setUp() - test.test_extension() - test.test_shapes() - test.test_compute_for_all_spikes() + # test.test_extension() + # test.test_shapes() + # test.test_compute_for_all_spikes() test.test_sparse() - test.test_project_new() + # test.test_project_new() From 73ceaacefecc4426d994ebca4ca006d667dada42 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 12:06:15 +0200 Subject: [PATCH 09/26] Extend PCA to be able to return sparse projections and fix tests --- .../postprocessing/principal_component.py | 16 ++++++++++------ .../tests/test_principal_component.py | 12 ++++++++---- .../tests/test_quality_metric_calculator.py | 7 ++++--- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 5d62216c20..8383dcbb43 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -72,7 +72,7 @@ def _select_extension_data(self, unit_ids): new_extension_data[k] = v return new_extension_data - def get_projections(self, unit_id): + def get_projections(self, unit_id, sparse=False): """ Returns the computed projections for the sampled waveforms of a unit id. @@ -80,16 +80,18 @@ def get_projections(self, unit_id): ---------- unit_id : int or str The unit id to return PCA projections for + sparse: bool, default False + If True, and sparsity is not None, only projections on sparse channels are returned. Returns ------- - proj: np.array + projections: np.array The PCA projections (num_waveforms, num_components, num_channels). In case sparsity is used, only the projections on sparse channels are returned. """ projections = self._extension_data[f"pca_{unit_id}"] mode = self._params["mode"] - if mode in ("by_channel_local", "by_channel_global"): + if mode in ("by_channel_local", "by_channel_global") and sparse: sparsity = self.get_sparsity() if sparsity is not None: projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] @@ -141,7 +143,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): all_labels = [] #  can be unit_id or unit_index all_projections = [] for unit_index, unit_id in enumerate(unit_ids): - proj = self.get_projections(unit_id) + proj = self.get_projections(unit_id, sparse=False) if channel_ids is not None: chan_inds = self.waveform_extractor.channel_ids_to_indices(channel_ids) proj = proj[:, :, chan_inds] @@ -158,7 +160,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): return all_labels, all_projections - def project_new(self, new_waveforms, unit_id=None): + def project_new(self, new_waveforms, unit_id=None, sparse=False): """ Projects new waveforms or traces snippets on the PC components. @@ -168,6 +170,8 @@ def project_new(self, new_waveforms, unit_id=None): Array with new waveforms to project with shape (num_waveforms, num_samples, num_channels) unit_id: int or str In case PCA is sparse and mode is by_channel_local, the unit_id of 'new_waveforms' + sparse: bool, default: False + If True, and sparsity is not None, only projections on sparse channels are returned. Returns ------- @@ -219,7 +223,7 @@ def project_new(self, new_waveforms, unit_id=None): projections = pca_model.transform(wfs_flat) # take care of sparsity (not in case of concatenated) - if mode in ("by_channel_local", "by_channel_global"): + if mode in ("by_channel_local", "by_channel_global") and sparse: if sparsity is not None: projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] return projections diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 04ce42b70e..49591d9b89 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -86,14 +86,18 @@ def test_sparse(self): pc.set_params(n_components=5, mode=mode, sparsity=sparsity) pc.run() for i, unit_id in enumerate(unit_ids): - proj = pc.get_projections(unit_id) - assert proj.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) + proj_sparse = pc.get_projections(unit_id, sparse=True) + assert proj_sparse.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) + proj_dense = pc.get_projections(unit_id, sparse=False) + assert proj_dense.shape[1:] == (5, num_channels) # test project_new unit_id = 3 new_wfs = we.get_waveforms(unit_id) - new_proj = pc.project_new(new_wfs, unit_id=unit_id) - assert new_proj.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) + new_proj_sparse = pc.project_new(new_wfs, unit_id=unit_id, sparse=True) + assert new_proj_sparse.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) + new_proj_dense = pc.project_new(new_wfs, unit_id=unit_id, sparse=False) + assert new_proj_dense.shape == (new_wfs.shape[0], 5, num_channels) if DEBUG: import matplotlib.pyplot as plt diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 4fa65993d1..977beca210 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -261,7 +261,8 @@ def test_nn_metrics(self): we_sparse, metric_names=metric_names, sparsity=None, seed=0, n_jobs=2 ) for metric_name in metrics.columns: - assert np.allclose(metrics[metric_name], metrics_par[metric_name]) + # NaNs are skipped + assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) def test_recordingless(self): we = self.we_long @@ -305,7 +306,7 @@ def test_empty_units(self): test.setUp() # test.test_drift_metrics() # test.test_extension() - # test.test_nn_metrics() + test.test_nn_metrics() # test.test_peak_sign() # test.test_empty_units() - test.test_recordingless() + # test.test_recordingless() From b9b6c15b42a64d877ea9fad9fca84424e2c97edf Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 12:12:21 +0200 Subject: [PATCH 10/26] Add test to check correct order of spikes with borders --- src/spikeinterface/core/tests/test_generate.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 3844e421ac..9a9c61766f 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -48,9 +48,15 @@ def test_generate_sorting_with_spikes_on_borders(): num_spikes_per_border=num_spikes_on_borders, border_size_samples=border_size_samples, ) + # check that segments are correctly sorted + all_spikes = sorting.to_spike_vector() + np.testing.assert_array_equal(all_spikes["segment_index"], np.sort(all_spikes["segment_index"])) + spikes = sorting.to_spike_vector(concatenated=False) # at least num_border spikes at borders for all segments - for i, spikes_in_segment in enumerate(spikes): + for spikes_in_segment in spikes: + # check that sample indices are correctly sorted within segments + np.testing.assert_array_equal(spikes_in_segment["sample_index"], np.sort(spikes_in_segment["sample_index"])) num_samples = int(segment_duration * 30000) assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders assert ( From 4e79b5811d41e6343391a3a6b26fab97f657368b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 13:32:51 +0200 Subject: [PATCH 11/26] propagate UnitSelector to others ipywidgets --- src/spikeinterface/widgets/amplitudes.py | 12 ++- src/spikeinterface/widgets/base.py | 3 +- src/spikeinterface/widgets/metrics.py | 21 ++-- src/spikeinterface/widgets/spike_locations.py | 34 +++---- .../widgets/spikes_on_traces.py | 87 ++++++++++------- src/spikeinterface/widgets/unit_locations.py | 29 +++--- src/spikeinterface/widgets/unit_waveforms.py | 50 +++++----- .../widgets/utils_ipywidgets.py | 96 ------------------- 8 files changed, 121 insertions(+), 211 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index b60de98cb0..5aa090b1b4 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -147,13 +147,16 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: bins = dp.bins ax_hist = self.axes.flatten()[1] - ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + # this is super slow, using plot and np.histogram is really much faster (and nicer!) + # ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + count, bins = np.histogram(amps, bins=bins) + ax_hist.plot(count, bins[:-1], color=dp.unit_colors[unit_id], alpha=0.8) if dp.plot_histograms: ax_hist = self.axes.flatten()[1] ax_hist.set_ylim(scatter_ax.get_ylim()) ax_hist.axis("off") - self.figure.tight_layout() + # self.figure.tight_layout() if dp.plot_legend: if hasattr(self, "legend") and self.legend is not None: @@ -174,7 +177,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # import ipywidgets.widgets as widgets import ipywidgets.widgets as W from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller, UnitSelector + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -200,7 +203,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.checkbox_histograms = W.Checkbox( value=data_plot["plot_histograms"], description="hist", - # disabled=False, ) left_sidebar = W.VBox( @@ -231,6 +233,7 @@ def _full_update_plot(self, change=None): data_plot = self.next_data_plot data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_histograms"] = self.checkbox_histograms.value + data_plot["plot_legend"] = False backend_kwargs = dict(figure=self.figure, axes=None, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) @@ -243,6 +246,7 @@ def _update_plot(self, change=None): data_plot = self.next_data_plot data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_histograms"] = self.checkbox_histograms.value + data_plot["plot_legend"] = False backend_kwargs = dict(figure=None, axes=self.axes, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 4ed83fcca9..1ff691320a 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -38,6 +38,7 @@ def set_default_plotter_backend(backend): "width_cm": "Width of the figure in cm (default 10)", "height_cm": "Height of the figure in cm (default 6)", "display": "If True, widgets are immediately displayed", + # "controllers": "" }, "ephyviewer": {}, } @@ -45,7 +46,7 @@ def set_default_plotter_backend(backend): default_backend_kwargs = { "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None}, "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, - "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True}, + "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True, "controllers": None}, "ephyviewer": {}, } diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 9dc51f522e..604da35e65 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -128,7 +128,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -147,34 +147,29 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): with output: self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - if data_plot["unit_ids"] is None: - data_plot["unit_ids"] = [] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller + self.unit_selector = UnitSelector(data_plot["sorting"].unit_ids) + self.unit_selector.value = [ ] - for w in self.controller.values(): - w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=self.figure.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update self._update_ipywidget(None) + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + if backend_kwargs["display"]: display(self.widget) def _update_ipywidget(self, change): from matplotlib.lines import Line2D - unit_ids = self.controller["unit_ids"].value + unit_ids = self.unit_selector.value unit_colors = self.data_plot["unit_colors"] # matplotlib next_data_plot dict update at each call @@ -198,6 +193,7 @@ def _update_ipywidget(self, change): self.plot_matplotlib(self.data_plot, **backend_kwargs) if len(unit_ids) > 0: + # TODO later make option to control legend or not for l in self.figure.legends: l.remove() handles = [ @@ -212,6 +208,7 @@ def _update_ipywidget(self, change): self.figure.canvas.draw() self.figure.canvas.flush_events() + def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 9771b2c0e9..926051b8f9 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -191,7 +191,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -210,48 +210,36 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], - list(data_plot["unit_colors"].keys()), - ratios[0] * width_cm, - height_cm, - ) - - self.controller = unit_controller - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] self.widget = widgets.AppLayout( center=fig.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids + data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_all_units"] = True + # TODO add an option checkbox for legend data_plot["plot_legend"] = True data_plot["hide_axis"] = True - backend_kwargs = {} - backend_kwargs["ax"] = self.ax + backend_kwargs = dict(ax=self.ax) - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() fig.canvas.draw() diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index ae036d1ba1..2f748cc0fc 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -149,20 +149,20 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting = we.sorting # first plot time series - ts_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) - self.ax = ts_widget.ax - self.axes = ts_widget.axes - self.figure = ts_widget.figure + traces_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) + self.ax = traces_widget.ax + self.axes = traces_widget.axes + self.figure = traces_widget.figure ax = self.ax - frame_range = ts_widget.data_plot["frame_range"] - segment_index = ts_widget.data_plot["segment_index"] - min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1]) - max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1]) + frame_range = traces_widget.data_plot["frame_range"] + segment_index = traces_widget.data_plot["segment_index"] + min_y = np.min(traces_widget.data_plot["channel_locations"][:, 1]) + max_y = np.max(traces_widget.data_plot["channel_locations"][:, 1]) - n = len(ts_widget.data_plot["channel_ids"]) - order = ts_widget.data_plot["order"] + n = len(traces_widget.data_plot["channel_ids"]) + order = traces_widget.data_plot["order"] if order is None: order = np.arange(n) @@ -210,13 +210,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # construct waveforms label_set = False if len(spike_frames_to_plot) > 0: - vspacing = ts_widget.data_plot["vspacing"] - traces = ts_widget.data_plot["list_traces"][0] + vspacing = traces_widget.data_plot["vspacing"] + traces = traces_widget.data_plot["list_traces"][0] waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1) + waveform_idxs = np.clip(waveform_idxs, 0, len(traces_widget.data_plot["times"]) - 1) - times = ts_widget.data_plot["times"][waveform_idxs] + times = traces_widget.data_plot["times"][waveform_idxs] # discontinuity times[:, -1] = np.nan @@ -224,7 +224,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): waveforms = traces[waveform_idxs] # [:, :, order] waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) - for i, chan_id in enumerate(ts_widget.data_plot["channel_ids"]): + for i, chan_id in enumerate(traces_widget.data_plot["channel_ids"]): offset = vspacing * i if chan_id in chan_ids: l = ax.plot(times_r, offset + waveforms_r[:, i], color=dp.unit_colors[unit]) @@ -232,13 +232,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): handles.append(l[0]) labels.append(unit) label_set = True - ax.legend(handles, labels) + # ax.legend(handles, labels) def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -256,37 +256,58 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] # plot timeseries - ts_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) - self.ax = ts_widget.ax - self.axes = ts_widget.axes - self.figure = ts_widget.figure + self._traces_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) + self.ax = self._traces_widget.ax + self.axes = self._traces_widget.axes + self.figure = self._traces_widget.figure - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.sampling_frequency = self._traces_widget.rec0.sampling_frequency - self.controller = dict() - self.controller.update(ts_widget.controller) - self.controller.update(unit_controller) + self.time_slider = self._traces_widget.time_slider - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] - self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0]) + self.widget = widgets.AppLayout(center=self._traces_widget.widget, + left_sidebar=self.unit_selector, + pane_widths=ratios + [0]) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + # remove callback from traces_widget + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self._traces_widget.time_slider.observe(self._update_ipywidget, names='value', type="change") + self._traces_widget.channel_selector.observe(self._update_ipywidget, names='value', type="change") + self._traces_widget.scaler.observe(self._update_ipywidget, names='value', type="change") + if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value + # TODO later: this is still a bit buggy because it make double refresh one from _traces_widget and one internal + + unit_ids = self.unit_selector.value + start_frame, end_frame, segment_index = self._traces_widget.time_slider.value + channel_ids = self._traces_widget.channel_selector.value + mode = self._traces_widget.mode_selector.value data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids + data_plot["options"].update( + dict( + channel_ids=channel_ids, + segment_index=segment_index, + # frame_range=(start_frame, end_frame), + time_range=np.array([start_frame, end_frame]) / self.sampling_frequency, + mode=mode, + with_colorbar=False, + ) + ) + backend_kwargs = {} backend_kwargs["ax"] = self.ax diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 42267e711f..8526a95d60 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -167,7 +167,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -186,42 +186,35 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] self.widget = widgets.AppLayout( center=fig.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids + data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_all_units"] = True + # TODO later add an option checkbox for legend data_plot["plot_legend"] = True data_plot["hide_axis"] = True - backend_kwargs = {} - backend_kwargs["ax"] = self.ax + backend_kwargs = dict(ax=self.ax) self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index e64765b44b..f01c842b66 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -250,7 +250,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -274,44 +274,33 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.fig_probe, self.ax_probe = plt.subplots(figsize=((ratios[2] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] + - same_axis_button = widgets.Checkbox( + self.same_axis_button = widgets.Checkbox( value=False, description="same axis", disabled=False, ) - plot_templates_button = widgets.Checkbox( + self.plot_templates_button = widgets.Checkbox( value=True, description="plot templates", disabled=False, ) - hide_axis_button = widgets.Checkbox( + self.hide_axis_button = widgets.Checkbox( value=True, description="hide axis", disabled=False, ) - footer = widgets.HBox([same_axis_button, plot_templates_button, hide_axis_button]) - - self.controller = { - "same_axis": same_axis_button, - "plot_templates": plot_templates_button, - "hide_axis": hide_axis_button, - } - self.controller.update(unit_controller) - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + footer = widgets.HBox([self.same_axis_button, self.plot_templates_button, self.hide_axis_button]) self.widget = widgets.AppLayout( center=self.fig_wf.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, right_sidebar=self.fig_probe.canvas, pane_widths=ratios, footer=footer, @@ -320,6 +309,11 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget(None) + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + for w in self.same_axis_button, self.plot_templates_button, self.hide_axis_button: + w.observe(self._update_ipywidget, names='value', type="change") + + if backend_kwargs["display"]: display(self.widget) @@ -327,10 +321,15 @@ def _update_ipywidget(self, change): self.fig_wf.clear() self.ax_probe.clear() - unit_ids = self.controller["unit_ids"].value - same_axis = self.controller["same_axis"].value - plot_templates = self.controller["plot_templates"].value - hide_axis = self.controller["hide_axis"].value + # unit_ids = self.controller["unit_ids"].value + unit_ids = self.unit_selector.value + # same_axis = self.controller["same_axis"].value + # plot_templates = self.controller["plot_templates"].value + # hide_axis = self.controller["hide_axis"].value + + same_axis = self.same_axis_button.value + plot_templates = self.plot_templates_button.value + hide_axis = self.hide_axis_button.value # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot @@ -341,6 +340,8 @@ def _update_ipywidget(self, change): data_plot["plot_templates"] = plot_templates if data_plot["plot_waveforms"]: data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} + + # TODO option for plot_legend backend_kwargs = {} @@ -369,6 +370,7 @@ def _update_ipywidget(self, change): self.ax_probe.axis("off") self.ax_probe.axis("equal") + # TODO this could be done with probeinterface plotting plotting tools!! for unit in unit_ids: channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] self.ax_probe.plot( diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index d2c41f234a..57550c0910 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -11,102 +11,6 @@ def check_ipywidget_backend(): assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" -# def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): -# time_slider = W.FloatSlider( -# orientation="horizontal", -# description="time:", -# value=time_range[0], -# min=t_start, -# max=t_stop, -# continuous_update=False, -# layout=W.Layout(width=f"{width_cm}cm"), -# ) -# layer_selector = W.Dropdown(description="layer", options=layer_keys) -# segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) -# window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") -# mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) -# all_layers = W.Checkbox(description="plot all layers", value=all_layers) - -# controller = { -# "layer_key": layer_selector, -# "segment_index": segment_selector, -# "window": window_sizer, -# "t_start": time_slider, -# "mode": mode_selector, -# "all_layers": all_layers, -# } -# widget = W.VBox( -# [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] -# ) - -# return widget, controller - - -def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): - unit_label = W.Label(value="units:") - - unit_selector = W.SelectMultiple( - options=all_unit_ids, - value=list(unit_ids), - disabled=False, - layout=W.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), - ) - - controller = {"unit_ids": unit_selector} - widget = W.VBox([unit_label, unit_selector]) - - return widget, controller - - -# def make_channel_controller(recording, width_cm, height_cm): -# channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) -# channel_selector = W.IntRangeSlider( -# value=[0, recording.get_num_channels()], -# min=0, -# max=recording.get_num_channels(), -# step=1, -# disabled=False, -# continuous_update=False, -# orientation="vertical", -# readout=True, -# readout_format="d", -# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), -# ) - -# controller = {"channel_inds": channel_selector} -# widget = W.VBox([channel_label, channel_selector]) - -# return widget, controller - - -# def make_scale_controller(width_cm, height_cm): -# scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) - -# plus_selector = W.Button( -# description="", -# disabled=False, -# button_style="", # 'success', 'info', 'warning', 'danger' or '' -# tooltip="Increase scale", -# icon="arrow-up", -# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), -# ) - -# minus_selector = W.Button( -# description="", -# disabled=False, -# button_style="", # 'success', 'info', 'warning', 'danger' or '' -# tooltip="Decrease scale", -# icon="arrow-down", -# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), -# ) - -# controller = {"plus": plus_selector, "minus": minus_selector} -# widget = W.VBox([scale_label, plus_selector, minus_selector]) - -# return widget, controller - - - class TimeSlider(W.HBox): value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) From f315594b0b88bed01f01232688d62c4c2e4bc0fe Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 15:49:47 +0200 Subject: [PATCH 12/26] protect TimeSlider on the upper limit to avoid border effect on window size --- src/spikeinterface/widgets/utils_ipywidgets.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 57550c0910..ee6133a990 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -54,7 +54,7 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): # description='time:', value=start_frame, min=0, - max=self.frame_limits[self.segment_index], + max=self.frame_limits[self.segment_index] - 1, readout=False, continuous_update=False, layout=W.Layout(width=f'70%') @@ -112,10 +112,13 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update else: start_frame = new_frame delta_s = self.window_sizer.value - end_frame = start_frame + int(delta_s * self.sampling_frequency) - + delta = int(delta_s * self.sampling_frequency) + # clip + start_frame = min(self.frame_limits[self.segment_index] - delta, start_frame) start_frame = max(0, start_frame) + end_frame = start_frame + delta + end_frame = min(self.frame_limits[self.segment_index], end_frame) @@ -170,7 +173,7 @@ def segment_changed(self, change=None): self.slider.unobserve(self.slider_moved, names='value', type="change") # self.slider.value = 0 - self.slider.max = self.frame_limits[self.segment_index] + self.slider.max = self.frame_limits[self.segment_index] - 1 self.slider.observe(self.slider_moved, names='value', type="change") self.update_time(new_frame=0, update_slider=True, update_label=True) From 2ba8928b785ed06f8a2f01b48ea632a4171ab926 Mon Sep 17 00:00:00 2001 From: Windows Home Date: Sun, 24 Sep 2023 13:51:48 -0500 Subject: [PATCH 13/26] Fix unit ID matching in sortingview curation Refine the logic for matching unit IDs in the sortingview curation process. Instead of using a potentially ambiguous containment check, unit IDs are now split at the '-' character, ensuring accurate mapping between unit labels and merged unit IDs. Additionally, introduced a unit test to validate the improved behavior and guard against potential false positives in future changes. --- .../curation/sortingview_curation.py | 3 +- .../tests/test_sortingview_curation.py | 45 +++++++++++++++++-- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 6adf9effd4..f595a67a3f 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -83,8 +83,9 @@ def apply_sortingview_curation( properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): labels_unit = [] + unit_id_parts = str(unit_id).split('-') for unit_label, labels in labels_dict.items(): - if unit_label in str(unit_id): + if unit_label in unit_id_parts: labels_unit.extend(labels) for label in labels_unit: properties[label][u_i] = True diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 9177cb5536..1b9e6f2800 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -1,8 +1,10 @@ import pytest from pathlib import Path import os - +import json +import numpy as np import spikeinterface as si +import spikeinterface.extractors as se from spikeinterface.extractors import read_mearec from spikeinterface import set_global_tmp_folder from spikeinterface.postprocessing import ( @@ -17,9 +19,7 @@ cache_folder = pytest.global_test_folder / "curation" else: cache_folder = Path("cache_folder") / "curation" - parent_folder = Path(__file__).parent - ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) @@ -111,6 +111,7 @@ def test_json_curation(): # from curation.json json_file = parent_folder / "sv-sorting-curation.json" sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + print(f"Sorting: {sorting.get_unit_ids()}") print(f"From JSON: {sorting_curated_json}") assert len(sorting_curated_json.unit_ids) == 9 @@ -130,9 +131,47 @@ def test_json_curation(): assert len(sorting_curated_json_mua.unit_ids) == 6 assert len(sorting_curated_json_mua1.unit_ids) == 5 +def test_false_positive_curation(): + # https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html + sampling_frequency = 30000. + duration = 20. + num_timepoints = int(sampling_frequency * duration) + num_units = 20 + num_spikes = 1000 + times0 = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels0 = np.random.randint(1, num_units + 1, size=num_spikes) + times1 = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels1 = np.random.randint(1, num_units + 1, size=num_spikes) + + sorting = se.NumpySorting.from_times_labels([times0, times1], [labels0, labels1], sampling_frequency) + print('Sorting: {}'.format(sorting.get_unit_ids())) + + # Test curation JSON: + test_json = { + "labelsByUnit": { + "1": ["accept"], + }, + "mergeGroups": [] + } + + json_path = "test_data.json" + with open(json_path, 'w') as f: + json.dump(test_json, f, indent=4) + + sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) + accept_idx = np.where(sorting_curated.get_property("accept"))[0] + sorting_curated_ids = sorting_curated.get_unit_ids() + print(f'Accepted unit IDs: {sorting_curated_ids[accept_idx]}') + + # Check if unit_id 1 has received the "accept" label. + assert sorting_curated.get_unit_property(unit_id=1, key="accept") + # Check if unit_id "#10" has received the "accept" label. + # If so, test fails since only unit_id 1 received the "accept" label in test_json. + assert not sorting_curated.get_unit_property(unit_id=10, key="accept") if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() test_gh_curation() test_json_curation() + test_false_positive_curation() From 45c69f52147edd406f293f731b7c7c687c700d29 Mon Sep 17 00:00:00 2001 From: Windows Home Date: Sun, 24 Sep 2023 14:46:01 -0500 Subject: [PATCH 14/26] Add merge check --- .gitignore | 1 + .../tests/test_sortingview_curation.py | 20 ++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 3ee3cb8867..7838213bed 100644 --- a/.gitignore +++ b/.gitignore @@ -188,3 +188,4 @@ test_folder/ # Mac OS .DS_Store +test_data.json diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 1b9e6f2800..c8a0788223 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -115,6 +115,7 @@ def test_json_curation(): print(f"From JSON: {sorting_curated_json}") assert len(sorting_curated_json.unit_ids) == 9 + print(sorting_curated_json.unit_ids) assert "#8-#9" in sorting_curated_json.unit_ids assert "accept" in sorting_curated_json.get_property_keys() assert "mua" in sorting_curated_json.get_property_keys() @@ -150,24 +151,29 @@ def test_false_positive_curation(): test_json = { "labelsByUnit": { "1": ["accept"], + "2": ["artifact"], + "12": ["artifact"] }, - "mergeGroups": [] + "mergeGroups": [[2,12]] } json_path = "test_data.json" with open(json_path, 'w') as f: json.dump(test_json, f, indent=4) - sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) - accept_idx = np.where(sorting_curated.get_property("accept"))[0] - sorting_curated_ids = sorting_curated.get_unit_ids() + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) + accept_idx = np.where(sorting_curated_json.get_property("accept"))[0] + sorting_curated_ids = sorting_curated_json.get_unit_ids() print(f'Accepted unit IDs: {sorting_curated_ids[accept_idx]}') # Check if unit_id 1 has received the "accept" label. - assert sorting_curated.get_unit_property(unit_id=1, key="accept") - # Check if unit_id "#10" has received the "accept" label. + assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") + # Check if unit_id 10 has received the "accept" label. # If so, test fails since only unit_id 1 received the "accept" label in test_json. - assert not sorting_curated.get_unit_property(unit_id=10, key="accept") + assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") + print(sorting_curated_json.unit_ids) + # Merging unit_ids of dtype int creates a new unit id + assert 21 in sorting_curated_json.unit_ids if __name__ == "__main__": # generate_sortingview_curation_dataset() From ffaf06756b3884646785fd81bce2d123abaaff0d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 24 Sep 2023 20:09:34 +0000 Subject: [PATCH 15/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/sortingview_curation.py | 2 +- .../tests/test_sortingview_curation.py | 33 ++++++++----------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index f595a67a3f..a5633fe165 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -83,7 +83,7 @@ def apply_sortingview_curation( properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): labels_unit = [] - unit_id_parts = str(unit_id).split('-') + unit_id_parts = str(unit_id).split("-") for unit_label, labels in labels_dict.items(): if unit_label in unit_id_parts: labels_unit.extend(labels) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index c8a0788223..a8944f0688 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -132,10 +132,11 @@ def test_json_curation(): assert len(sorting_curated_json_mua.unit_ids) == 6 assert len(sorting_curated_json_mua1.unit_ids) == 5 + def test_false_positive_curation(): # https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html - sampling_frequency = 30000. - duration = 20. + sampling_frequency = 30000.0 + duration = 20.0 num_timepoints = int(sampling_frequency * duration) num_units = 20 num_spikes = 1000 @@ -145,36 +146,30 @@ def test_false_positive_curation(): labels1 = np.random.randint(1, num_units + 1, size=num_spikes) sorting = se.NumpySorting.from_times_labels([times0, times1], [labels0, labels1], sampling_frequency) - print('Sorting: {}'.format(sorting.get_unit_ids())) + print("Sorting: {}".format(sorting.get_unit_ids())) # Test curation JSON: - test_json = { - "labelsByUnit": { - "1": ["accept"], - "2": ["artifact"], - "12": ["artifact"] - }, - "mergeGroups": [[2,12]] - } + test_json = {"labelsByUnit": {"1": ["accept"], "2": ["artifact"], "12": ["artifact"]}, "mergeGroups": [[2, 12]]} json_path = "test_data.json" - with open(json_path, 'w') as f: + with open(json_path, "w") as f: json.dump(test_json, f, indent=4) sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) accept_idx = np.where(sorting_curated_json.get_property("accept"))[0] sorting_curated_ids = sorting_curated_json.get_unit_ids() - print(f'Accepted unit IDs: {sorting_curated_ids[accept_idx]}') + print(f"Accepted unit IDs: {sorting_curated_ids[accept_idx]}") - # Check if unit_id 1 has received the "accept" label. - assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") - # Check if unit_id 10 has received the "accept" label. - # If so, test fails since only unit_id 1 received the "accept" label in test_json. - assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") + # Check if unit_id 1 has received the "accept" label. + assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") + # Check if unit_id 10 has received the "accept" label. + # If so, test fails since only unit_id 1 received the "accept" label in test_json. + assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") print(sorting_curated_json.unit_ids) - # Merging unit_ids of dtype int creates a new unit id + # Merging unit_ids of dtype int creates a new unit id assert 21 in sorting_curated_json.unit_ids + if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() From 57bb3a734978d207f12733eb4c4807cb8e22c06f Mon Sep 17 00:00:00 2001 From: Windows Home Date: Tue, 26 Sep 2023 22:54:41 -0500 Subject: [PATCH 16/26] Implement more tests to ensure int and string unit IDs merging, inheriting labels, etc. --- .../curation/sortingview_curation.py | 49 +++-- .../tests/test_sortingview_curation.py | 195 +++++++++++++++--- 2 files changed, 202 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index f595a67a3f..b7f0cab6a0 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -57,38 +57,52 @@ def apply_sortingview_curation( unit_ids_dtype = sorting.unit_ids.dtype # STEP 1: merge groups + labels_dict = sortingview_curation_dict["labelsByUnit"] if "mergeGroups" in sortingview_curation_dict and not skip_merge: merge_groups = sortingview_curation_dict["mergeGroups"] - for mg in merge_groups: + for merge_group in merge_groups: + # Store labels of units that are about to be merged + labels_to_inherit = [] + for unit in merge_group: + labels_to_inherit.extend(labels_dict.get(str(unit), [])) + labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates + if verbose: - print(f"Merging {mg}") + print(f"Merging {merge_group}") if unit_ids_dtype.kind in ("U", "S"): # if unit dtype is str, set new id as "{unit1}-{unit2}" - new_unit_id = "-".join(mg) + new_unit_id = "-".join(merge_group) + curation_sorting.merge(merge_group, new_unit_id=new_unit_id) else: # in this case, the CurationSorting takes care of finding a new unused int - new_unit_id = None - curation_sorting.merge(mg, new_unit_id=new_unit_id) + curation_sorting.merge(merge_group, new_unit_id=None) + new_unit_id = curation_sorting.max_used_id # merged unit id + labels_dict[str(new_unit_id)] = labels_to_inherit # STEP 2: gather and apply sortingview curation labels - # In sortingview, a unit is not required to have all labels. # For example, the first 3 units could be labeled as "accept". # In this case, the first 3 values of the property "accept" will be True, the rest False - labels_dict = sortingview_curation_dict["labelsByUnit"] - properties = {} - for _, labels in labels_dict.items(): - for label in labels: - if label not in properties: - properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) + + # Initialize the properties dictionary + properties = {label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) + for labels in labels_dict.values() for label in labels} + + # Populate the properties dictionary for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - labels_unit = [] - unit_id_parts = str(unit_id).split('-') - for unit_label, labels in labels_dict.items(): - if unit_label in unit_id_parts: - labels_unit.extend(labels) + labels_unit = set() + + # Check for exact match first + if str(unit_id) in labels_dict: + labels_unit.update(labels_dict[str(unit_id)]) + # If no exact match, check if unit_label is a substring of unit_id (for string unit ID merged unit) + else: + for unit_label, labels in labels_dict.items(): + if isinstance(unit_id, str) and unit_label in unit_id: + labels_unit.update(labels) for label in labels_unit: properties[label][u_i] = True + for prop_name, prop_values in properties.items(): curation_sorting.current_sorting.set_property(prop_name, prop_values) @@ -104,5 +118,4 @@ def apply_sortingview_curation( units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True]) units_to_remove = np.unique(units_to_remove) curation_sorting.remove_units(units_to_remove) - return curation_sorting.current_sorting diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index c8a0788223..48923aa15d 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -3,6 +3,7 @@ import os import json import numpy as np + import spikeinterface as si import spikeinterface.extractors as se from spikeinterface.extractors import read_mearec @@ -14,11 +15,11 @@ compute_spike_amplitudes, ) from spikeinterface.curation import apply_sortingview_curation - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "curation" else: cache_folder = Path("cache_folder") / "curation" + parent_folder = Path(__file__).parent ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) @@ -27,6 +28,7 @@ set_global_tmp_folder(cache_folder) + # this needs to be run only once def generate_sortingview_curation_dataset(): import spikeinterface.widgets as sw @@ -50,15 +52,15 @@ def generate_sortingview_curation_dataset(): @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_gh_curation(): + """ + Test curation using GitHub URI. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) - - # from GH # curated link: # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22gh://alejoe91/spikeinterface/fix-codecov/spikeinterface/curation/tests/sv-sorting-curation.json%22} gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json" sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri, verbose=True) - print(f"From GH: {sorting_curated_gh}") assert len(sorting_curated_gh.unit_ids) == 9 assert "#8-#9" in sorting_curated_gh.unit_ids @@ -75,9 +77,13 @@ def test_gh_curation(): assert len(sorting_curated_gh_mua.unit_ids) == 6 assert len(sorting_curated_gh_art_mua.unit_ids) == 5 + print("Test for GH passed!\n") @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_sha1_curation(): + """ + Test curation using SHA1 URI. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) @@ -93,7 +99,7 @@ def test_sha1_curation(): assert "accept" in sorting_curated_sha1.get_property_keys() assert "mua" in sorting_curated_sha1.get_property_keys() assert "artifact" in sorting_curated_sha1.get_property_keys() - + unit_ids = sorting_curated_sha1.unit_ids sorting_curated_sha1_accepted = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, include_labels=["accept"]) sorting_curated_sha1_mua = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, exclude_labels=["mua"]) sorting_curated_sha1_art_mua = apply_sortingview_curation( @@ -103,19 +109,21 @@ def test_sha1_curation(): assert len(sorting_curated_sha1_mua.unit_ids) == 6 assert len(sorting_curated_sha1_art_mua.unit_ids) == 5 + print("Test for sha1 curation passed!\n") def test_json_curation(): + """ + Test curation using a JSON file. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) # from curation.json json_file = parent_folder / "sv-sorting-curation.json" - sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) print(f"Sorting: {sorting.get_unit_ids()}") - print(f"From JSON: {sorting_curated_json}") + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) assert len(sorting_curated_json.unit_ids) == 9 - print(sorting_curated_json.unit_ids) assert "#8-#9" in sorting_curated_json.unit_ids assert "accept" in sorting_curated_json.get_property_keys() assert "mua" in sorting_curated_json.get_property_keys() @@ -131,20 +139,23 @@ def test_json_curation(): assert len(sorting_curated_json_accepted.unit_ids) == 3 assert len(sorting_curated_json_mua.unit_ids) == 6 assert len(sorting_curated_json_mua1.unit_ids) == 5 + + print("Test for json curation passed!\n") def test_false_positive_curation(): + """ + Test curation for false positives. + """ # https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html sampling_frequency = 30000. duration = 20. num_timepoints = int(sampling_frequency * duration) num_units = 20 num_spikes = 1000 - times0 = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) - labels0 = np.random.randint(1, num_units + 1, size=num_spikes) - times1 = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) - labels1 = np.random.randint(1, num_units + 1, size=num_spikes) + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.randint(1, num_units + 1, size=num_spikes) - sorting = se.NumpySorting.from_times_labels([times0, times1], [labels0, labels1], sampling_frequency) + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) print('Sorting: {}'.format(sorting.get_unit_ids())) # Test curation JSON: @@ -161,23 +172,159 @@ def test_false_positive_curation(): with open(json_path, 'w') as f: json.dump(test_json, f, indent=4) + # Apply curation sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) - accept_idx = np.where(sorting_curated_json.get_property("accept"))[0] - sorting_curated_ids = sorting_curated_json.get_unit_ids() - print(f'Accepted unit IDs: {sorting_curated_ids[accept_idx]}') - - # Check if unit_id 1 has received the "accept" label. - assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") - # Check if unit_id 10 has received the "accept" label. - # If so, test fails since only unit_id 1 received the "accept" label in test_json. - assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") - print(sorting_curated_json.unit_ids) - # Merging unit_ids of dtype int creates a new unit id + print('Curated:', sorting_curated_json.get_unit_ids()) + + # Assertions + assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") + assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") assert 21 in sorting_curated_json.unit_ids + print("False positive test for integer unit IDs passed!\n") + +def test_label_inheritance_int(): + """ + Test curation for label inheritance for integer unit IDs. + """ + # Setup + sampling_frequency = 30000. + duration = 20. + num_timepoints = int(sampling_frequency * duration) + num_spikes = 1000 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.randint(1, 8, size=num_spikes) # 7 units: 1 to 7 + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + + # Create a curation JSON with labels and merge groups + curation_dict = { + "labelsByUnit": { + "1": ["mua"], + "2": ["mua"], + "3": ["reject"], + "4": ["noise"], + "5": ["accept"], + "6": ["accept"], + "7": ["accept"] + }, + "mergeGroups": [[1, 2], [3, 4], [5, 6]] + } + + json_path = "test_curation_int.json" + with open(json_path, 'w') as f: + json.dump(curation_dict, f, indent=4) + + # Apply curation + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_path) + + # Assertions for merged units + print(f"Merge only: {sorting_merge.get_unit_ids()}") + assert sorting_merge.get_unit_property(unit_id=8, key="mua") # 8 = merged unit of 1 and 2 + assert not sorting_merge.get_unit_property(unit_id=8, key="reject") + assert not sorting_merge.get_unit_property(unit_id=8, key="noise") + assert not sorting_merge.get_unit_property(unit_id=8, key="accept") + + assert not sorting_merge.get_unit_property(unit_id=9, key="mua") # 9 = merged unit of 3 and 4 + assert sorting_merge.get_unit_property(unit_id=9, key="reject") + assert sorting_merge.get_unit_property(unit_id=9, key="noise") + assert not sorting_merge.get_unit_property(unit_id=9, key="accept") + + assert not sorting_merge.get_unit_property(unit_id=10, key="mua") # 10 = merged unit of 5 and 6 + assert not sorting_merge.get_unit_property(unit_id=10, key="reject") + assert not sorting_merge.get_unit_property(unit_id=10, key="noise") + assert sorting_merge.get_unit_property(unit_id=10, key="accept") + + # Assertions for exclude_labels + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_path, exclude_labels=["noise"]) + print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + assert 9 not in sorting_exclude_noise.get_unit_ids() + + # Assertions for include_labels + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_path, include_labels=["accept"]) + print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + assert 8 not in sorting_include_accept.get_unit_ids() + assert 9 not in sorting_include_accept.get_unit_ids() + assert 10 in sorting_include_accept.get_unit_ids() + + print("Test for integer unit IDs passed!\n") + + +def test_label_inheritance_str(): + """ + Test curation for label inheritance for string unit IDs. + """ + sampling_frequency = 30000. + duration = 20. + num_timepoints = int(sampling_frequency * duration) + num_spikes = 1000 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.choice(['a', 'b', 'c', 'd', 'e', 'f', 'g'], size=num_spikes) + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + print(f"Sorting: {sorting.get_unit_ids()}") + # Create a curation JSON with labels and merge groups + curation_dict = { + "labelsByUnit": { + "a": ["mua"], + "b": ["mua"], + "c": ["reject"], + "d": ["noise"], + "e": ["accept"], + "f": ["accept"], + "g": ["accept"] + }, + "mergeGroups": [["a", "b"], ["c", "d"], ["e", "f"]] + } + + json_path = "test_curation_str.json" + with open(json_path, 'w') as f: + json.dump(curation_dict, f, indent=4) + + # Check label inheritance for merged units + merged_id_1 = "a-b" + merged_id_2 = "c-d" + merged_id_3 = "e-f" + # Apply curation + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) + + # Assertions for merged units + print(f"Merge only: {sorting_merge.get_unit_ids()}") + assert sorting_merge.get_unit_property(unit_id="a-b", key="mua") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="reject") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="noise") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="accept") + + assert not sorting_merge.get_unit_property(unit_id="c-d", key="mua") + assert sorting_merge.get_unit_property(unit_id="c-d", key="reject") + assert sorting_merge.get_unit_property(unit_id="c-d", key="noise") + assert not sorting_merge.get_unit_property(unit_id="c-d", key="accept") + + assert not sorting_merge.get_unit_property(unit_id="e-f", key="mua") + assert not sorting_merge.get_unit_property(unit_id="e-f", key="reject") + assert not sorting_merge.get_unit_property(unit_id="e-f", key="noise") + assert sorting_merge.get_unit_property(unit_id="e-f", key="accept") + + # Assertions for exclude_labels + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_path, exclude_labels=["noise"]) + print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + assert "c-d" not in sorting_exclude_noise.get_unit_ids() + + # Assertions for include_labels + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_path, include_labels=["accept"]) + print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + assert "a-b" not in sorting_include_accept.get_unit_ids() + assert "c-d" not in sorting_include_accept.get_unit_ids() + assert "e-f" in sorting_include_accept.get_unit_ids() + + print("Test for string unit IDs passed!\n") + + if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() test_gh_curation() test_json_curation() test_false_positive_curation() + test_label_inheritance_int() + test_label_inheritance_str() \ No newline at end of file From a8e07a71d8306550a20a6a611222fb76190d3178 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 04:01:49 +0000 Subject: [PATCH 17/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/sortingview_curation.py | 9 ++++-- .../tests/test_sortingview_curation.py | 31 ++++++++++--------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 7ae8e41030..f83ff3352b 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -76,7 +76,7 @@ def apply_sortingview_curation( else: # in this case, the CurationSorting takes care of finding a new unused int curation_sorting.merge(merge_group, new_unit_id=None) - new_unit_id = curation_sorting.max_used_id # merged unit id + new_unit_id = curation_sorting.max_used_id # merged unit id labels_dict[str(new_unit_id)] = labels_to_inherit # STEP 2: gather and apply sortingview curation labels @@ -85,8 +85,11 @@ def apply_sortingview_curation( # In this case, the first 3 values of the property "accept" will be True, the rest False # Initialize the properties dictionary - properties = {label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) - for labels in labels_dict.values() for label in labels} + properties = { + label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) + for labels in labels_dict.values() + for label in labels + } # Populate the properties dictionary for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 958df6acb5..cfc15013a3 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -15,6 +15,7 @@ compute_spike_amplitudes, ) from spikeinterface.curation import apply_sortingview_curation + if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "curation" else: @@ -28,7 +29,6 @@ set_global_tmp_folder(cache_folder) - # this needs to be run only once def generate_sortingview_curation_dataset(): import spikeinterface.widgets as sw @@ -79,6 +79,7 @@ def test_gh_curation(): print("Test for GH passed!\n") + @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_sha1_curation(): """ @@ -111,6 +112,7 @@ def test_sha1_curation(): print("Test for sha1 curation passed!\n") + def test_json_curation(): """ Test curation using a JSON file. @@ -157,7 +159,7 @@ def test_false_positive_curation(): labels = np.random.randint(1, num_units + 1, size=num_spikes) sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) - print('Sorting: {}'.format(sorting.get_unit_ids())) + print("Sorting: {}".format(sorting.get_unit_ids())) # Test curation JSON: test_json = {"labelsByUnit": {"1": ["accept"], "2": ["artifact"], "12": ["artifact"]}, "mergeGroups": [[2, 12]]} @@ -168,7 +170,7 @@ def test_false_positive_curation(): # Apply curation sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) - print('Curated:', sorting_curated_json.get_unit_ids()) + print("Curated:", sorting_curated_json.get_unit_ids()) # Assertions assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") @@ -177,13 +179,14 @@ def test_false_positive_curation(): print("False positive test for integer unit IDs passed!\n") + def test_label_inheritance_int(): """ Test curation for label inheritance for integer unit IDs. """ # Setup - sampling_frequency = 30000. - duration = 20. + sampling_frequency = 30000.0 + duration = 20.0 num_timepoints = int(sampling_frequency * duration) num_spikes = 1000 times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) @@ -200,13 +203,13 @@ def test_label_inheritance_int(): "4": ["noise"], "5": ["accept"], "6": ["accept"], - "7": ["accept"] + "7": ["accept"], }, - "mergeGroups": [[1, 2], [3, 4], [5, 6]] + "mergeGroups": [[1, 2], [3, 4], [5, 6]], } json_path = "test_curation_int.json" - with open(json_path, 'w') as f: + with open(json_path, "w") as f: json.dump(curation_dict, f, indent=4) # Apply curation @@ -248,12 +251,12 @@ def test_label_inheritance_str(): """ Test curation for label inheritance for string unit IDs. """ - sampling_frequency = 30000. - duration = 20. + sampling_frequency = 30000.0 + duration = 20.0 num_timepoints = int(sampling_frequency * duration) num_spikes = 1000 times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) - labels = np.random.choice(['a', 'b', 'c', 'd', 'e', 'f', 'g'], size=num_spikes) + labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes) sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) print(f"Sorting: {sorting.get_unit_ids()}") @@ -266,13 +269,13 @@ def test_label_inheritance_str(): "d": ["noise"], "e": ["accept"], "f": ["accept"], - "g": ["accept"] + "g": ["accept"], }, - "mergeGroups": [["a", "b"], ["c", "d"], ["e", "f"]] + "mergeGroups": [["a", "b"], ["c", "d"], ["e", "f"]], } json_path = "test_curation_str.json" - with open(json_path, 'w') as f: + with open(json_path, "w") as f: json.dump(curation_dict, f, indent=4) # Check label inheritance for merged units From 2c015f78e9311e106e9d2fda4e4026a61ca68c5b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 09:28:28 +0000 Subject: [PATCH 18/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/amplitudes.py | 7 +- src/spikeinterface/widgets/base.py | 2 +- src/spikeinterface/widgets/metrics.py | 6 +- src/spikeinterface/widgets/spike_locations.py | 2 +- .../widgets/spikes_on_traces.py | 20 +- src/spikeinterface/widgets/traces.py | 51 ++-- src/spikeinterface/widgets/unit_locations.py | 2 +- src/spikeinterface/widgets/unit_waveforms.py | 8 +- .../widgets/utils_ipywidgets.py | 222 +++++++++--------- 9 files changed, 163 insertions(+), 157 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 5aa090b1b4..6b6496a577 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -174,6 +174,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt + # import ipywidgets.widgets as widgets import ipywidgets.widgets as W from IPython.display import display @@ -210,7 +211,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector, self.checkbox_histograms, ], - layout = W.Layout(align_items="center", width="4cm", height="100%"), + layout=W.Layout(align_items="center", width="4cm", height="100%"), ) self.widget = W.AppLayout( @@ -222,8 +223,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._full_update_plot() - self.unit_selector.observe(self._update_plot, names='value', type="change") - self.checkbox_histograms.observe(self._full_update_plot, names='value', type="change") + self.unit_selector.observe(self._update_plot, names="value", type="change") + self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change") if backend_kwargs["display"]: display(self.widget) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 1ff691320a..9fc7b73707 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -38,7 +38,7 @@ def set_default_plotter_backend(backend): "width_cm": "Width of the figure in cm (default 10)", "height_cm": "Height of the figure in cm (default 6)", "display": "If True, widgets are immediately displayed", - # "controllers": "" + # "controllers": "" }, "ephyviewer": {}, } diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 604da35e65..c7b701c8b0 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -149,8 +149,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): plt.show() self.unit_selector = UnitSelector(data_plot["sorting"].unit_ids) - self.unit_selector.value = [ ] - + self.unit_selector.value = [] self.widget = widgets.AppLayout( center=self.figure.canvas, @@ -161,7 +160,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget(None) - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) @@ -208,7 +207,6 @@ def _update_ipywidget(self, change): self.figure.canvas.draw() self.figure.canvas.flush_events() - def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 926051b8f9..fda2356105 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -222,7 +222,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget() - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 2f748cc0fc..c2bed8fe41 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -232,7 +232,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): handles.append(l[0]) labels.append(unit) label_set = True - # ax.legend(handles, labels) + # ax.legend(handles, labels) def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt @@ -268,19 +268,18 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector = UnitSelector(data_plot["unit_ids"]) self.unit_selector.value = list(data_plot["unit_ids"])[:1] - self.widget = widgets.AppLayout(center=self._traces_widget.widget, - left_sidebar=self.unit_selector, - pane_widths=ratios + [0]) + self.widget = widgets.AppLayout( + center=self._traces_widget.widget, left_sidebar=self.unit_selector, pane_widths=ratios + [0] + ) # a first update self._update_ipywidget() # remove callback from traces_widget - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") - self._traces_widget.time_slider.observe(self._update_ipywidget, names='value', type="change") - self._traces_widget.channel_selector.observe(self._update_ipywidget, names='value', type="change") - self._traces_widget.scaler.observe(self._update_ipywidget, names='value', type="change") - + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.time_slider.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.channel_selector.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.scaler.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) @@ -305,10 +304,9 @@ def _update_ipywidget(self, change=None): time_range=np.array([start_frame, end_frame]) / self.sampling_frequency, mode=mode, with_colorbar=False, - ) + ) ) - backend_kwargs = {} backend_kwargs["ax"] = self.ax diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index d107c5cb23..9b6716e8f3 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -290,7 +290,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): check_ipywidget_backend() self.next_data_plot = data_plot.copy() - self.recordings = data_plot["recordings"] @@ -314,7 +313,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.time_slider = TimeSlider( durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], sampling_frequency=rec0.sampling_frequency, - # layout=W.Layout(height="2cm"), + # layout=W.Layout(height="2cm"), ) start_frame = int(data_plot["time_range"][0] * rec0.sampling_frequency) @@ -324,14 +323,17 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): _layer_keys = data_plot["layer_keys"] if len(_layer_keys) > 1: - _layer_keys = ['ALL'] + _layer_keys - self.layer_selector = W.Dropdown(options=_layer_keys, - layout=W.Layout(width="95%"), - ) - self.mode_selector = W.Dropdown(options=["line", "map"], value=data_plot["mode"], - # layout=W.Layout(width="5cm"), - layout=W.Layout(width="95%"), - ) + _layer_keys = ["ALL"] + _layer_keys + self.layer_selector = W.Dropdown( + options=_layer_keys, + layout=W.Layout(width="95%"), + ) + self.mode_selector = W.Dropdown( + options=["line", "map"], + value=data_plot["mode"], + # layout=W.Layout(width="5cm"), + layout=W.Layout(width="95%"), + ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) @@ -343,9 +345,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.mode_selector, self.scaler, # self.channel_selector, - ], + ], layout=W.Layout(width="3.5cm"), - align_items='center', + align_items="center", ) self.return_scaled = data_plot["return_scaled"] @@ -353,7 +355,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.widget = widgets.AppLayout( center=self.figure.canvas, footer=self.time_slider, - left_sidebar = left_sidebar, + left_sidebar=left_sidebar, right_sidebar=self.channel_selector, pane_heights=[0, 6, 1], pane_widths=ratios, @@ -365,28 +367,28 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # callbacks: # some widgets generate a full retrieve + refresh - self.time_slider.observe(self._retrieve_traces, names='value', type="change") - self.layer_selector.observe(self._retrieve_traces, names='value', type="change") - self.channel_selector.observe(self._retrieve_traces, names='value', type="change") + self.time_slider.observe(self._retrieve_traces, names="value", type="change") + self.layer_selector.observe(self._retrieve_traces, names="value", type="change") + self.channel_selector.observe(self._retrieve_traces, names="value", type="change") # other widgets only refresh - self.scaler.observe(self._update_plot, names='value', type="change") + self.scaler.observe(self._update_plot, names="value", type="change") # map is a special case because needs to check layer also - self.mode_selector.observe(self._mode_changed, names='value', type="change") - + self.mode_selector.observe(self._mode_changed, names="value", type="change") + if backend_kwargs["display"]: # self.check_backend() display(self.widget) def _get_layers(self): layer = self.layer_selector.value - if layer == 'ALL': + if layer == "ALL": layer_keys = self.data_plot["layer_keys"] else: layer_keys = [layer] if self.mode_selector.value == "map": layer_keys = layer_keys[:1] return layer_keys - + def _mode_changed(self, change=None): if self.mode_selector.value == "map" and self.layer_selector.value == "ALL": self.layer_selector.value = self.data_plot["layer_keys"][0] @@ -400,7 +402,7 @@ def _retrieve_traces(self, change=None): order, _ = order_channels_by_depth(self.rec0, channel_ids) else: order = None - + start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency @@ -439,9 +441,9 @@ def _update_plot(self, change=None): data_plot["clims"] = clims data_plot["channel_ids"] = self._channel_ids - + data_plot["layer_keys"] = layer_keys - data_plot["colors"] = {k:self.data_plot["colors"][k] for k in layer_keys} + data_plot["colors"] = {k: self.data_plot["colors"][k] for k in layer_keys} list_traces = [traces * self.scaler.value for traces in self._list_traces] data_plot["list_traces"] = list_traces @@ -458,7 +460,6 @@ def _update_plot(self, change=None): fig.canvas.draw() fig.canvas.flush_events() - def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import handle_display_and_url diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 8526a95d60..b41ee3508b 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -198,7 +198,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget() - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index f01c842b66..8ffc931bf2 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -277,7 +277,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector = UnitSelector(data_plot["unit_ids"]) self.unit_selector.value = list(data_plot["unit_ids"])[:1] - self.same_axis_button = widgets.Checkbox( value=False, description="same axis", @@ -309,10 +308,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget(None) - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") for w in self.same_axis_button, self.plot_templates_button, self.hide_axis_button: - w.observe(self._update_ipywidget, names='value', type="change") - + w.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) @@ -340,7 +338,7 @@ def _update_ipywidget(self, change): data_plot["plot_templates"] = plot_templates if data_plot["plot_waveforms"]: data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} - + # TODO option for plot_legend backend_kwargs = {} diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index ee6133a990..6e872eca55 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -12,12 +12,9 @@ def check_ipywidget_backend(): class TimeSlider(W.HBox): - value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) - - def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): - - + + def __init__(self, durations, sampling_frequency, time_range=(0, 1.0), **kwargs): self.num_segments = len(durations) self.frame_limits = [int(sampling_frequency * d) for d in durations] self.sampling_frequency = sampling_frequency @@ -28,81 +25,100 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): self.segment_index = 0 self.value = (start_frame, end_frame, self.segment_index) - - + layout = W.Layout(align_items="center", width="2.5cm", height="1.cm") - but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) - but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) - + but_left = W.Button(description="", disabled=False, button_style="", icon="arrow-left", layout=layout) + but_right = W.Button(description="", disabled=False, button_style="", icon="arrow-right", layout=layout) + but_left.on_click(self.move_left) but_right.on_click(self.move_right) - self.move_size = W.Dropdown(options=['10 ms', '100 ms', '1 s', '10 s', '1 m', '30 m', '1 h',], # '6 h', '24 h' - value='1 s', - description='', - layout = W.Layout(width="2cm") - ) + self.move_size = W.Dropdown( + options=[ + "10 ms", + "100 ms", + "1 s", + "10 s", + "1 m", + "30 m", + "1 h", + ], # '6 h', '24 h' + value="1 s", + description="", + layout=W.Layout(width="2cm"), + ) # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) - self.time_label = W.Text(value=f'{time_range[0]}',description='', - disabled=False, layout=W.Layout(width='2.5cm')) - self.time_label.observe(self.time_label_changed, names='value', type="change") - + self.time_label = W.Text( + value=f"{time_range[0]}", description="", disabled=False, layout=W.Layout(width="2.5cm") + ) + self.time_label.observe(self.time_label_changed, names="value", type="change") self.slider = W.IntSlider( - orientation='horizontal', - # description='time:', + orientation="horizontal", + # description='time:', value=start_frame, min=0, max=self.frame_limits[self.segment_index] - 1, readout=False, continuous_update=False, - layout=W.Layout(width=f'70%') + layout=W.Layout(width=f"70%"), ) - - self.slider.observe(self.slider_moved, names='value', type="change") - + + self.slider.observe(self.slider_moved, names="value", type="change") + delta_s = np.diff(self.frame_range) / sampling_frequency - - self.window_sizer = W.BoundedFloatText(value=delta_s, step=1, - min=0.01, max=30., - description='win (s)', - layout=W.Layout(width='auto') - # layout=W.Layout(width=f'10%') - ) - self.window_sizer.observe(self.win_size_changed, names='value', type="change") + + self.window_sizer = W.BoundedFloatText( + value=delta_s, + step=1, + min=0.01, + max=30.0, + description="win (s)", + layout=W.Layout(width="auto") + # layout=W.Layout(width=f'10%') + ) + self.window_sizer.observe(self.win_size_changed, names="value", type="change") self.segment_selector = W.Dropdown(description="segment", options=list(range(self.num_segments))) - self.segment_selector.observe(self.segment_changed, names='value', type="change") + self.segment_selector.observe(self.segment_changed, names="value", type="change") + + super(W.HBox, self).__init__( + children=[ + self.segment_selector, + but_left, + self.move_size, + but_right, + self.slider, + self.time_label, + self.window_sizer, + ], + layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) - super(W.HBox, self).__init__(children=[self.segment_selector, but_left, self.move_size, but_right, - self.slider, self.time_label, self.window_sizer], - layout=W.Layout(align_items="center", width="100%", height="100%"), - **kwargs) - - self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def value_changed(self, change=None): - - self.unobserve(self.value_changed, names=['value'], type="change") + self.unobserve(self.value_changed, names=["value"], type="change") start, stop, seg_index = self.value if seg_index < 0 or seg_index >= self.num_segments: - self.value = change['old'] + self.value = change["old"] return if start < 0 or stop < 0: - self.value = change['old'] + self.value = change["old"] return if start >= self.frame_limits[seg_index] or start > self.frame_limits[seg_index]: - self.value = change['old'] + self.value = change["old"] return - + self.segment_selector.value = seg_index self.update_time(new_frame=start, update_slider=True, update_label=True) delta_s = (stop - start) / self.sampling_frequency self.window_sizer.value = delta_s - self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def update_time(self, new_frame=None, new_time=None, update_slider=False, update_label=False): if new_frame is None and new_time is None: @@ -118,25 +134,24 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update start_frame = min(self.frame_limits[self.segment_index] - delta, start_frame) start_frame = max(0, start_frame) end_frame = start_frame + delta - + end_frame = min(self.frame_limits[self.segment_index], end_frame) - start_time = start_frame / self.sampling_frequency if update_label: - self.time_label.unobserve(self.time_label_changed, names='value', type="change") - self.time_label.value = f'{start_time}' - self.time_label.observe(self.time_label_changed, names='value', type="change") + self.time_label.unobserve(self.time_label_changed, names="value", type="change") + self.time_label.value = f"{start_time}" + self.time_label.observe(self.time_label_changed, names="value", type="change") if update_slider: - self.slider.unobserve(self.slider_moved, names='value', type="change") + self.slider.unobserve(self.slider_moved, names="value", type="change") self.slider.value = start_frame - self.slider.observe(self.slider_moved, names='value', type="change") - + self.slider.observe(self.slider_moved, names="value", type="change") + self.frame_range = (start_frame, end_frame) self.value = (start_frame, end_frame, self.segment_index) - + def time_label_changed(self, change=None): try: new_time = float(self.time_label.value) @@ -145,39 +160,39 @@ def time_label_changed(self, change=None): if new_time is not None: self.update_time(new_time=new_time, update_slider=True) - def win_size_changed(self, change=None): self.update_time() - + def slider_moved(self, change=None): new_frame = self.slider.value self.update_time(new_frame=new_frame, update_label=True) - + def move(self, sign): - value, units = self.move_size.value.split(' ') + value, units = self.move_size.value.split(" ") value = int(value) - delta_s = (sign * np.timedelta64(value, units)) / np.timedelta64(1, 's') + delta_s = (sign * np.timedelta64(value, units)) / np.timedelta64(1, "s") delta_sample = int(delta_s * self.sampling_frequency) new_frame = self.frame_range[0] + delta_sample self.slider.value = new_frame - + def move_left(self, change=None): self.move(-1) def move_right(self, change=None): self.move(+1) - + def segment_changed(self, change=None): self.segment_index = self.segment_selector.value - self.slider.unobserve(self.slider_moved, names='value', type="change") + self.slider.unobserve(self.slider_moved, names="value", type="change") # self.slider.value = 0 self.slider.max = self.frame_limits[self.segment_index] - 1 - self.slider.observe(self.slider_moved, names='value', type="change") + self.slider.observe(self.slider_moved, names="value", type="change") self.update_time(new_frame=0, update_slider=True, update_label=True) + class ChannelSelector(W.VBox): value = traitlets.List() @@ -211,22 +226,24 @@ def __init__(self, channel_ids, **kwargs): ) hbox = W.HBox(children=[self.slider, self.selector]) - super(W.VBox, self).__init__(children=[channel_label, hbox], - layout=W.Layout(align_items="center"), - # layout=W.Layout(align_items="center", width="100%", height="100%"), - **kwargs) - self.slider.observe(self.on_slider_changed, names=['value'], type="change") - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + super(W.VBox, self).__init__( + children=[channel_label, hbox], + layout=W.Layout(align_items="center"), + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + self.selector.observe(self.on_selector_changed, names=["value"], type="change") # TODO external value change # self.observe(self.value_changed, names=['value'], type="change") - + def on_slider_changed(self, change=None): i0, i1 = self.slider.value - - self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") + + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") self.selector.value = self.channel_ids[i0:i1][::-1] - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + self.selector.observe(self.on_selector_changed, names=["value"], type="change") self.value = self.channel_ids[i0:i1] @@ -235,27 +252,23 @@ def on_selector_changed(self, change=None): channel_ids = channel_ids[::-1] if len(channel_ids) > 0: - self.slider.unobserve(self.on_slider_changed, names=['value'], type="change") + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") i0 = self.channel_ids.index(channel_ids[0]) i1 = self.channel_ids.index(channel_ids[-1]) + 1 self.slider.value = (i0, i1) - self.slider.observe(self.on_slider_changed, names=['value'], type="change") + self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.value = channel_ids - class ScaleWidget(W.VBox): value = traitlets.Float() - def __init__(self, value=1., factor=1.2, **kwargs): - - assert factor > 1. + def __init__(self, value=1.0, factor=1.2, **kwargs): + assert factor > 1.0 self.factor = factor - self.scale_label = W.Label("Scale", - layout=W.Layout(layout=W.Layout(width='95%'), - justify_content="center")) + self.scale_label = W.Label("Scale", layout=W.Layout(layout=W.Layout(width="95%"), justify_content="center")) self.plus_selector = W.Button( description="", @@ -264,7 +277,7 @@ def __init__(self, value=1., factor=1.2, **kwargs): tooltip="Increase scale", icon="arrow-up", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='60%', align_self='center'), + layout=W.Layout(width="60%", align_self="center"), ) self.minus_selector = W.Button( @@ -274,31 +287,31 @@ def __init__(self, value=1., factor=1.2, **kwargs): tooltip="Decrease scale", icon="arrow-down", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='60%', align_self='center'), + layout=W.Layout(width="60%", align_self="center"), ) self.plus_selector.on_click(self.plus_clicked) self.minus_selector.on_click(self.minus_clicked) - self.value = 1. - super(W.VBox, self).__init__(children=[self.plus_selector, self.scale_label, self.minus_selector], - # layout=W.Layout(align_items="center", width="100%", height="100%"), - **kwargs) + self.value = 1.0 + super(W.VBox, self).__init__( + children=[self.plus_selector, self.scale_label, self.minus_selector], + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) self.update_label() - self.observe(self.value_changed, names=['value'], type="change") - + self.observe(self.value_changed, names=["value"], type="change") + def update_label(self): self.scale_label.value = f"Scale: {self.value:0.2f}" - def plus_clicked(self, change=None): self.value = self.value * self.factor def minus_clicked(self, change=None): self.value = self.value / self.factor - def value_changed(self, change=None): self.update_label() @@ -319,20 +332,17 @@ def __init__(self, unit_ids, **kwargs): layout=W.Layout(height="100%", width="2cm"), ) - super(W.VBox, self).__init__(children=[label, self.selector], - layout=W.Layout(align_items="center"), - **kwargs) - - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + super(W.VBox, self).__init__(children=[label, self.selector], layout=W.Layout(align_items="center"), **kwargs) + + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + self.observe(self.value_changed, names=["value"], type="change") - self.observe(self.value_changed, names=['value'], type="change") - def on_selector_changed(self, change=None): unit_ids = self.selector.value self.value = unit_ids - - def value_changed(self, change=None): - self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") - self.selector.value = change['new'] - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") From 8e4b43a4f67a92a1497eda5d53f2be2e04f7779f Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Wed, 27 Sep 2023 11:37:12 +0200 Subject: [PATCH 19/26] Update src/spikeinterface/postprocessing/amplitude_scalings.py --- src/spikeinterface/postprocessing/amplitude_scalings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 8823fd6257..7e6c95a875 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -431,7 +431,7 @@ def _are_unit_indices_overlapping(sparsity_mask, i, j): bool True if the unit indices i and j are overlapping, False otherwise """ - if np.sum(np.logical_and(sparsity_mask[i], sparsity_mask[j])) > 0: + if np.any(sparsity_mask[i] & sparsity_mask[j]): return True else: return False From fb82e029be652fa33b69367d9d97f9c7a465914e Mon Sep 17 00:00:00 2001 From: Robin Kim <31869753+rkim48@users.noreply.github.com> Date: Wed, 27 Sep 2023 10:16:37 -0500 Subject: [PATCH 20/26] Apply suggestions from code review Remove print('success') statements Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- .../curation/tests/test_sortingview_curation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index cfc15013a3..79cea3d010 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -77,7 +77,6 @@ def test_gh_curation(): assert len(sorting_curated_gh_mua.unit_ids) == 6 assert len(sorting_curated_gh_art_mua.unit_ids) == 5 - print("Test for GH passed!\n") @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") @@ -110,7 +109,6 @@ def test_sha1_curation(): assert len(sorting_curated_sha1_mua.unit_ids) == 6 assert len(sorting_curated_sha1_art_mua.unit_ids) == 5 - print("Test for sha1 curation passed!\n") def test_json_curation(): @@ -244,7 +242,6 @@ def test_label_inheritance_int(): assert 9 not in sorting_include_accept.get_unit_ids() assert 10 in sorting_include_accept.get_unit_ids() - print("Test for integer unit IDs passed!\n") def test_label_inheritance_str(): @@ -314,7 +311,6 @@ def test_label_inheritance_str(): assert "c-d" not in sorting_include_accept.get_unit_ids() assert "e-f" in sorting_include_accept.get_unit_ids() - print("Test for string unit IDs passed!\n") if __name__ == "__main__": From 776520bb100986bd90653d9b8eeba77eb0cc16aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:16:55 +0000 Subject: [PATCH 21/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/tests/test_sortingview_curation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 79cea3d010..71912d7793 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -78,7 +78,6 @@ def test_gh_curation(): assert len(sorting_curated_gh_art_mua.unit_ids) == 5 - @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_sha1_curation(): """ @@ -110,7 +109,6 @@ def test_sha1_curation(): assert len(sorting_curated_sha1_art_mua.unit_ids) == 5 - def test_json_curation(): """ Test curation using a JSON file. @@ -243,7 +241,6 @@ def test_label_inheritance_int(): assert 10 in sorting_include_accept.get_unit_ids() - def test_label_inheritance_str(): """ Test curation for label inheritance for string unit IDs. @@ -312,7 +309,6 @@ def test_label_inheritance_str(): assert "e-f" in sorting_include_accept.get_unit_ids() - if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() From a85b4a8d666311325e74feaf05e47656048355ea Mon Sep 17 00:00:00 2001 From: Windows Home Date: Thu, 28 Sep 2023 09:39:22 -0500 Subject: [PATCH 22/26] Simplify label assignment logic and add test.json files to tests directory --- .../curation/sortingview_curation.py | 19 ++--- .../sv-sorting-curation-false-positive.json | 19 +++++ .../tests/sv-sorting-curation-int.json | 39 ++++++++++ .../tests/sv-sorting-curation-str.json | 39 ++++++++++ .../tests/test_sortingview_curation.py | 71 +++---------------- 5 files changed, 114 insertions(+), 73 deletions(-) create mode 100644 src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json create mode 100644 src/spikeinterface/curation/tests/sv-sorting-curation-int.json create mode 100644 src/spikeinterface/curation/tests/sv-sorting-curation-str.json diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index f83ff3352b..7a573c38c4 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -77,7 +77,7 @@ def apply_sortingview_curation( # in this case, the CurationSorting takes care of finding a new unused int curation_sorting.merge(merge_group, new_unit_id=None) new_unit_id = curation_sorting.max_used_id # merged unit id - labels_dict[str(new_unit_id)] = labels_to_inherit + labels_dict[str(new_unit_id)] = labels_to_inherit # STEP 2: gather and apply sortingview curation labels # In sortingview, a unit is not required to have all labels. @@ -92,19 +92,12 @@ def apply_sortingview_curation( } # Populate the properties dictionary - for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - labels_unit = set() - + for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): + unit_id_str = str(unit_id) # Check for exact match first - if str(unit_id) in labels_dict: - labels_unit.update(labels_dict[str(unit_id)]) - # If no exact match, check if unit_label is a substring of unit_id (for string unit ID merged unit) - else: - for unit_label, labels in labels_dict.items(): - if isinstance(unit_id, str) and unit_label in unit_id: - labels_unit.update(labels) - for label in labels_unit: - properties[label][u_i] = True + if unit_id_str in labels_dict: + for label in labels_dict[unit_id_str]: + properties[label][unit_index] = True for prop_name, prop_values in properties.items(): curation_sorting.current_sorting.set_property(prop_name, prop_values) diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json new file mode 100644 index 0000000000..5c29328363 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json @@ -0,0 +1,19 @@ +{ + "labelsByUnit": { + "1": [ + "accept" + ], + "2": [ + "artifact" + ], + "12": [ + "artifact" + ] + }, + "mergeGroups": [ + [ + 2, + 12 + ] + ] +} \ No newline at end of file diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-int.json b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json new file mode 100644 index 0000000000..486a51a583 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json @@ -0,0 +1,39 @@ +{ + "labelsByUnit": { + "1": [ + "mua" + ], + "2": [ + "mua" + ], + "3": [ + "reject" + ], + "4": [ + "noise" + ], + "5": [ + "accept" + ], + "6": [ + "accept" + ], + "7": [ + "accept" + ] + }, + "mergeGroups": [ + [ + 1, + 2 + ], + [ + 3, + 4 + ], + [ + 5, + 6 + ] + ] +} \ No newline at end of file diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-str.json b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json new file mode 100644 index 0000000000..b2ab1d5849 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json @@ -0,0 +1,39 @@ +{ + "labelsByUnit": { + "a": [ + "mua" + ], + "b": [ + "mua" + ], + "c": [ + "reject" + ], + "d": [ + "noise" + ], + "e": [ + "accept" + ], + "f": [ + "accept" + ], + "g": [ + "accept" + ] + }, + "mergeGroups": [ + [ + "a", + "b" + ], + [ + "c", + "d" + ], + [ + "e", + "f" + ] + ] +} \ No newline at end of file diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 71912d7793..1579c9f03b 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -138,8 +138,6 @@ def test_json_curation(): assert len(sorting_curated_json_mua.unit_ids) == 6 assert len(sorting_curated_json_mua1.unit_ids) == 5 - print("Test for json curation passed!\n") - def test_false_positive_curation(): """ @@ -157,15 +155,8 @@ def test_false_positive_curation(): sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) print("Sorting: {}".format(sorting.get_unit_ids())) - # Test curation JSON: - test_json = {"labelsByUnit": {"1": ["accept"], "2": ["artifact"], "12": ["artifact"]}, "mergeGroups": [[2, 12]]} - - json_path = "test_data.json" - with open(json_path, "w") as f: - json.dump(test_json, f, indent=4) - - # Apply curation - sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) + json_file = parent_folder / "sv-sorting-curation-false-positive.json" + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) print("Curated:", sorting_curated_json.get_unit_ids()) # Assertions @@ -173,8 +164,6 @@ def test_false_positive_curation(): assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") assert 21 in sorting_curated_json.unit_ids - print("False positive test for integer unit IDs passed!\n") - def test_label_inheritance_int(): """ @@ -190,26 +179,8 @@ def test_label_inheritance_int(): sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) - # Create a curation JSON with labels and merge groups - curation_dict = { - "labelsByUnit": { - "1": ["mua"], - "2": ["mua"], - "3": ["reject"], - "4": ["noise"], - "5": ["accept"], - "6": ["accept"], - "7": ["accept"], - }, - "mergeGroups": [[1, 2], [3, 4], [5, 6]], - } - - json_path = "test_curation_int.json" - with open(json_path, "w") as f: - json.dump(curation_dict, f, indent=4) - - # Apply curation - sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_path) + json_file = parent_folder / "sv-sorting-curation-int.json" + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file) # Assertions for merged units print(f"Merge only: {sorting_merge.get_unit_ids()}") @@ -229,12 +200,12 @@ def test_label_inheritance_int(): assert sorting_merge.get_unit_property(unit_id=10, key="accept") # Assertions for exclude_labels - sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_path, exclude_labels=["noise"]) + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") assert 9 not in sorting_exclude_noise.get_unit_ids() # Assertions for include_labels - sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_path, include_labels=["accept"]) + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) print(f"Include accept: {sorting_include_accept.get_unit_ids()}") assert 8 not in sorting_include_accept.get_unit_ids() assert 9 not in sorting_include_accept.get_unit_ids() @@ -254,30 +225,10 @@ def test_label_inheritance_str(): sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) print(f"Sorting: {sorting.get_unit_ids()}") - # Create a curation JSON with labels and merge groups - curation_dict = { - "labelsByUnit": { - "a": ["mua"], - "b": ["mua"], - "c": ["reject"], - "d": ["noise"], - "e": ["accept"], - "f": ["accept"], - "g": ["accept"], - }, - "mergeGroups": [["a", "b"], ["c", "d"], ["e", "f"]], - } - - json_path = "test_curation_str.json" - with open(json_path, "w") as f: - json.dump(curation_dict, f, indent=4) - - # Check label inheritance for merged units - merged_id_1 = "a-b" - merged_id_2 = "c-d" - merged_id_3 = "e-f" + # Apply curation - sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) + json_file = parent_folder / "sv-sorting-curation-str.json" + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) # Assertions for merged units print(f"Merge only: {sorting_merge.get_unit_ids()}") @@ -297,12 +248,12 @@ def test_label_inheritance_str(): assert sorting_merge.get_unit_property(unit_id="e-f", key="accept") # Assertions for exclude_labels - sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_path, exclude_labels=["noise"]) + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") assert "c-d" not in sorting_exclude_noise.get_unit_ids() # Assertions for include_labels - sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_path, include_labels=["accept"]) + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) print(f"Include accept: {sorting_include_accept.get_unit_ids()}") assert "a-b" not in sorting_include_accept.get_unit_ids() assert "c-d" not in sorting_include_accept.get_unit_ids() From 54d40eb2a0cc4468100fd8a058cb8a6b8354fd67 Mon Sep 17 00:00:00 2001 From: Windows Home Date: Thu, 28 Sep 2023 09:52:29 -0500 Subject: [PATCH 23/26] Comment out print statements --- .../tests/test_sortingview_curation.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 1579c9f03b..a620cb8db1 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -91,7 +91,7 @@ def test_sha1_curation(): # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22%22} sha1_uri = "sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22" sorting_curated_sha1 = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, verbose=True) - print(f"From SHA: {sorting_curated_sha1}") + # print(f"From SHA: {sorting_curated_sha1}") assert len(sorting_curated_sha1.unit_ids) == 9 assert "#8-#9" in sorting_curated_sha1.unit_ids @@ -118,7 +118,7 @@ def test_json_curation(): # from curation.json json_file = parent_folder / "sv-sorting-curation.json" - print(f"Sorting: {sorting.get_unit_ids()}") + # print(f"Sorting: {sorting.get_unit_ids()}") sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) assert len(sorting_curated_json.unit_ids) == 9 @@ -153,11 +153,11 @@ def test_false_positive_curation(): labels = np.random.randint(1, num_units + 1, size=num_spikes) sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) - print("Sorting: {}".format(sorting.get_unit_ids())) + # print("Sorting: {}".format(sorting.get_unit_ids())) json_file = parent_folder / "sv-sorting-curation-false-positive.json" sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) - print("Curated:", sorting_curated_json.get_unit_ids()) + # print("Curated:", sorting_curated_json.get_unit_ids()) # Assertions assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") @@ -183,7 +183,7 @@ def test_label_inheritance_int(): sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file) # Assertions for merged units - print(f"Merge only: {sorting_merge.get_unit_ids()}") + # print(f"Merge only: {sorting_merge.get_unit_ids()}") assert sorting_merge.get_unit_property(unit_id=8, key="mua") # 8 = merged unit of 1 and 2 assert not sorting_merge.get_unit_property(unit_id=8, key="reject") assert not sorting_merge.get_unit_property(unit_id=8, key="noise") @@ -201,12 +201,12 @@ def test_label_inheritance_int(): # Assertions for exclude_labels sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) - print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") assert 9 not in sorting_exclude_noise.get_unit_ids() # Assertions for include_labels sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) - print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") assert 8 not in sorting_include_accept.get_unit_ids() assert 9 not in sorting_include_accept.get_unit_ids() assert 10 in sorting_include_accept.get_unit_ids() @@ -224,14 +224,14 @@ def test_label_inheritance_str(): labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes) sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) - print(f"Sorting: {sorting.get_unit_ids()}") + # print(f"Sorting: {sorting.get_unit_ids()}") # Apply curation json_file = parent_folder / "sv-sorting-curation-str.json" sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) # Assertions for merged units - print(f"Merge only: {sorting_merge.get_unit_ids()}") + # print(f"Merge only: {sorting_merge.get_unit_ids()}") assert sorting_merge.get_unit_property(unit_id="a-b", key="mua") assert not sorting_merge.get_unit_property(unit_id="a-b", key="reject") assert not sorting_merge.get_unit_property(unit_id="a-b", key="noise") @@ -249,17 +249,16 @@ def test_label_inheritance_str(): # Assertions for exclude_labels sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) - print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") assert "c-d" not in sorting_exclude_noise.get_unit_ids() # Assertions for include_labels sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) - print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") assert "a-b" not in sorting_include_accept.get_unit_ids() assert "c-d" not in sorting_include_accept.get_unit_ids() assert "e-f" in sorting_include_accept.get_unit_ids() - if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() From f1b7bfe668ac8ff0581f252241edfb004577551d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 2023 14:53:07 +0000 Subject: [PATCH 24/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/tests/sv-sorting-curation-false-positive.json | 2 +- src/spikeinterface/curation/tests/sv-sorting-curation-int.json | 2 +- src/spikeinterface/curation/tests/sv-sorting-curation-str.json | 2 +- src/spikeinterface/curation/tests/test_sortingview_curation.py | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json index 5c29328363..48881388bb 100644 --- a/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json @@ -16,4 +16,4 @@ 12 ] ] -} \ No newline at end of file +} diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-int.json b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json index 486a51a583..2047c514ce 100644 --- a/src/spikeinterface/curation/tests/sv-sorting-curation-int.json +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json @@ -36,4 +36,4 @@ 6 ] ] -} \ No newline at end of file +} diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-str.json b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json index b2ab1d5849..2585b5cc50 100644 --- a/src/spikeinterface/curation/tests/sv-sorting-curation-str.json +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json @@ -36,4 +36,4 @@ "f" ] ] -} \ No newline at end of file +} diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index a620cb8db1..22085f2f77 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -259,6 +259,7 @@ def test_label_inheritance_str(): assert "c-d" not in sorting_include_accept.get_unit_ids() assert "e-f" in sorting_include_accept.get_unit_ids() + if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() From f76e9d895a321eceb8dd6e01f0e3fe769867ec16 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 10:14:50 +0200 Subject: [PATCH 25/26] Update src/spikeinterface/curation/sortingview_curation.py --- src/spikeinterface/curation/sortingview_curation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 7a573c38c4..626ea79eb9 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -94,7 +94,6 @@ def apply_sortingview_curation( # Populate the properties dictionary for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): unit_id_str = str(unit_id) - # Check for exact match first if unit_id_str in labels_dict: for label in labels_dict[unit_id_str]: properties[label][unit_index] = True From c20ffdadb908d601e546323b113e994445546891 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 10:23:47 +0200 Subject: [PATCH 26/26] Tiny rewrite in tests --- src/spikeinterface/curation/tests/test_sortingview_curation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 22085f2f77..ce6c7dd5a6 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -174,8 +174,9 @@ def test_label_inheritance_int(): duration = 20.0 num_timepoints = int(sampling_frequency * duration) num_spikes = 1000 + num_units = 7 times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) - labels = np.random.randint(1, 8, size=num_spikes) # 7 units: 1 to 7 + labels = np.random.randint(1, 1 + num_units, size=num_spikes) # 7 units: 1 to 7 sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency)