diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index e9726a16da..d3f875959e 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -602,6 +602,15 @@ def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95),
     if unit_ids is None:
         unit_ids = sorting.unit_ids
 
+    if all(
+        [
+            waveform_extractor.get_num_samples(segment_index) < bin_size_samples
+            for segment_index in range(waveform_extractor.get_num_segments())
+        ]
+    ):
+        warnings.warn(f"Bin size of {bin_size_s}s is larger than each segment duration. Firing ranges are set to NaN.")
+        return {unit_id: np.nan for unit_id in unit_ids}
+
     # for each segment, we compute the firing rate histogram and we concatenate them
     firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids}
     for segment_index in range(waveform_extractor.get_num_segments()):
diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index 2d63a06b17..8a32c4cee8 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -220,6 +220,10 @@ def test_calculate_firing_range(waveform_extractor_simple):
     firing_ranges = compute_firing_ranges(we)
     print(firing_ranges)
 
+    with pytest.warns(UserWarning) as w:
+        firing_ranges_nan = compute_firing_ranges(we, bin_size_s=we.get_total_duration() + 1)
+    assert np.all([np.isnan(f) for f in firing_ranges_nan.values()])
+
 
 def test_calculate_amplitude_cutoff(waveform_extractor_simple):
     we = waveform_extractor_simple
@@ -378,7 +382,7 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
 if __name__ == "__main__":
     sim_data = _simulated_data()
     we = _waveform_extractor_simple()
-    we_violations = _waveform_extractor_violations(sim_data)
+    # we_violations = _waveform_extractor_violations(sim_data)
     # test_calculate_amplitude_cutoff(we)
     # test_calculate_presence_ratio(we)
     # test_calculate_amplitude_median(we)
@@ -387,4 +391,4 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
     # test_calculate_drift_metrics(we)
     # test_synchrony_metrics(we)
     test_calculate_firing_range(we)
-    test_calculate_amplitude_cv_metrics(we)
+    # test_calculate_amplitude_cv_metrics(we)
diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py
index c505676c05..abf40b2da6 100644
--- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py
+++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py
@@ -584,13 +584,13 @@ def plot_motions_several_benchmarks(benchmarks):
     _simpleaxis(ax)
 
 
-def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None): 
+def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None):
     if ax is None:
         fig, ax = plt.subplots(figsize=(5, 5))
 
     for count, benchmark in enumerate(benchmarks):
         color = colors[count] if colors is not None else None
-        
+
         if detailed:
             bottom = 0
             i = 0
@@ -606,8 +606,6 @@ def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=No
         else:
             total_run_time = np.sum([value for key, value in benchmark.run_times.items()])
             ax.bar([count], [total_run_time], color=color, edgecolor="black")
 
-
-    # ax.legend()
 
     ax.set_ylabel("speed (s)")
diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py
index 8e5afb2e8e..b28b29f17c 100644
--- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py
+++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py
@@ -193,11 +193,15 @@ def run_sorters(self, skip_already_done=True):
             recording = self.recordings[case["recording"]]
             output_folder = self.folder / f"tmp_sortings_{label}"
             if output_folder.exists() and skip_already_done:
-                print('already done')
+                print("already done")
                 sorting = read_sorter_folder(output_folder)
             else:
                 sorting = run_sorter(
-                    sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder
+                    sorter_name,
+                    recording,
+                    output_folder,
+                    **sorter_params,
+                    delete_output_folder=self.delete_output_folder,
                 )
             self.sortings[label] = sorting
 