Skip to content

Commit

Permalink
Merge pull request #2032 from samuelgarcia/drift_paper_figure
Browse files Browse the repository at this point in the history
Minor changes on drift benchmark for figures
  • Loading branch information
alejoe91 authored Sep 29, 2023
2 parents 243a30c + 02c17b9 commit 2f354c4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo
mean_error = np.sqrt(np.mean((errors) ** 2, axis=1))
depth_error = np.sqrt(np.mean((errors) ** 2, axis=0))

axes[0].plot(benchmark.temporal_bins, mean_error, label=benchmark.title, color=c)
axes[0].plot(benchmark.temporal_bins, mean_error, lw=1, label=benchmark.title, color=c)
parts = axes[1].violinplot(mean_error, [count], showmeans=True)
if c is not None:
for pc in parts["bodies"]:
Expand All @@ -500,8 +500,8 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo
axes[2].plot(benchmark.spatial_bins, depth_error, label=benchmark.title, color=c)

ax0 = ax = axes[0]
ax.set_xlabel("time [s]")
ax.set_ylabel("error [um]")
ax.set_xlabel("Time [s]")
ax.set_ylabel("Error [μm]")
if show_legend:
ax.legend()
_simpleaxis(ax)
Expand All @@ -514,7 +514,7 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo

ax2 = axes[2]
ax2.set_yticks([])
ax2.set_xlabel("depth [um]")
ax2.set_xlabel("Depth [μm]")
# ax.set_ylabel('error')
channel_positions = benchmark.recording.get_channel_locations()
probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max()
Expand Down Expand Up @@ -584,23 +584,30 @@ def plot_motions_several_benchmarks(benchmarks):
_simpleaxis(ax)


def plot_speed_several_benchmarks(benchmarks, ax=None, colors=None):
def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None):
if ax is None:
fig, ax = plt.subplots(figsize=(5, 5))

for count, benchmark in enumerate(benchmarks):
color = colors[count] if colors is not None else None
bottom = 0
i = 0
patterns = ["/", "\\", "|", "*"]
for key, value in benchmark.run_times.items():
if count == 0:
label = key.replace("_", " ")
else:
label = None
ax.bar([count], [value], label=label, bottom=bottom, color=color, edgecolor="black", hatch=patterns[i])
bottom += value
i += 1

if detailed:
bottom = 0
i = 0
patterns = ["/", "\\", "|", "*"]
for key, value in benchmark.run_times.items():
if count == 0:
label = key.replace("_", " ")
else:
label = None
ax.bar([count], [value], label=label, bottom=bottom, color=color, edgecolor="black", hatch=patterns[i])
bottom += value
i += 1
else:
total_run_time = np.sum([value for key, value in benchmark.run_times.items()])
ax.bar([count], [total_run_time], color=color, edgecolor="black")



# ax.legend()
ax.set_ylabel("speed (s)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from spikeinterface.extractors import read_mearec
from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten
from spikeinterface.sorters import run_sorter
from spikeinterface.sorters import run_sorter, read_sorter_folder
from spikeinterface.widgets import plot_unit_waveforms, plot_gt_performances

from spikeinterface.comparison import GroundTruthComparison
Expand Down Expand Up @@ -184,17 +184,21 @@ def extract_waveforms(self):
we.run_extract_waveforms(seed=22051977, **self.job_kwargs)
self.waveforms[key] = we

def run_sorters(self):
def run_sorters(self, skip_already_done=True):
for case in self.sorter_cases:
label = case["label"]
print("run sorter", label)
sorter_name = case["sorter_name"]
sorter_params = case["sorter_params"]
recording = self.recordings[case["recording"]]
output_folder = self.folder / f"tmp_sortings_{label}"
sorting = run_sorter(
sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder
)
if output_folder.exists() and skip_already_done:
print('already done')
sorting = read_sorter_folder(output_folder)
else:
sorting = run_sorter(
sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder
)
self.sortings[label] = sorting

def compute_distances_to_static(self, force=False):
Expand Down

0 comments on commit 2f354c4

Please sign in to comment.