diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index dd35670abd..c505676c05 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -487,7 +487,7 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - axes[0].plot(benchmark.temporal_bins, mean_error, label=benchmark.title, color=c) + axes[0].plot(benchmark.temporal_bins, mean_error, lw=1, label=benchmark.title, color=c) parts = axes[1].violinplot(mean_error, [count], showmeans=True) if c is not None: for pc in parts["bodies"]: @@ -500,8 +500,8 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo axes[2].plot(benchmark.spatial_bins, depth_error, label=benchmark.title, color=c) ax0 = ax = axes[0] - ax.set_xlabel("time [s]") - ax.set_ylabel("error [um]") + ax.set_xlabel("Time [s]") + ax.set_ylabel("Error [μm]") if show_legend: ax.legend() _simpleaxis(ax) @@ -514,7 +514,7 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo ax2 = axes[2] ax2.set_yticks([]) - ax2.set_xlabel("depth [um]") + ax2.set_xlabel("Depth [μm]") # ax.set_ylabel('error') channel_positions = benchmark.recording.get_channel_locations() probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() @@ -584,23 +584,30 @@ def plot_motions_several_benchmarks(benchmarks): _simpleaxis(ax) -def plot_speed_several_benchmarks(benchmarks, ax=None, colors=None): +def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None): if ax is None: fig, ax = plt.subplots(figsize=(5, 5)) for count, benchmark in enumerate(benchmarks): color = colors[count] if colors is not None else None - bottom = 0 - i = 0 - patterns = ["/", "\\", "|", "*"] - for key, value in benchmark.run_times.items(): - if count == 0: - label = key.replace("_", " ") - else: - label = None - ax.bar([count], [value], label=label, bottom=bottom, color=color, edgecolor="black", hatch=patterns[i]) - bottom += value - i += 1 + + if detailed: + bottom = 0 + i = 0 + patterns = ["/", "\\", "|", "*"] + for key, value in benchmark.run_times.items(): + if count == 0: + label = key.replace("_", " ") + else: + label = None + ax.bar([count], [value], label=label, bottom=bottom, color=color, edgecolor="black", hatch=patterns[i]) + bottom += value + i += 1 + else: + total_run_time = np.sum([value for key, value in benchmark.run_times.items()]) + ax.bar([count], [total_run_time], color=color, edgecolor="black") + + # ax.legend() ax.set_ylabel("speed (s)") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index 13a64e8168..8e5afb2e8e 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -9,7 +9,7 @@ from spikeinterface.extractors import read_mearec from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten -from spikeinterface.sorters import run_sorter +from spikeinterface.sorters import run_sorter, read_sorter_folder from spikeinterface.widgets import plot_unit_waveforms, plot_gt_performances from spikeinterface.comparison import GroundTruthComparison @@ -184,7 +184,7 @@ def extract_waveforms(self): we.run_extract_waveforms(seed=22051977, **self.job_kwargs) self.waveforms[key] = we - def run_sorters(self): + def run_sorters(self, skip_already_done=True): for case in self.sorter_cases: label = case["label"] print("run sorter", label) @@ -192,9 +192,13 @@ def run_sorters(self): sorter_params = case["sorter_params"] recording = self.recordings[case["recording"]] output_folder = self.folder / f"tmp_sortings_{label}" - sorting = run_sorter( - sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder - ) + if output_folder.exists() and skip_already_done: + print('already done') + sorting = read_sorter_folder(output_folder) + else: + sorting = run_sorter( + sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder + ) self.sortings[label] = sorting def compute_distances_to_static(self, force=False):