Tdc peeler (#3466)
Improving the Peeler
samuelgarcia authored Oct 11, 2024
1 parent bbf7daf commit 0ae32e7
Showing 5 changed files with 608 additions and 257 deletions.
73 changes: 8 additions & 65 deletions src/spikeinterface/benchmark/benchmark_matching.py
@@ -33,7 +33,7 @@ def run(self, **job_kwargs):
sorting["unit_index"] = spikes["cluster_index"]
sorting["segment_index"] = spikes["segment_index"]
sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids)
self.result = {"sorting": sorting}
self.result = {"sorting": sorting, "spikes": spikes}
self.result["templates"] = self.templates

def compute_result(self, with_collision=False, **result_params):
@@ -45,6 +45,7 @@ def compute_result(self, with_collision=False, **result_params):

_run_key_saved = [
("sorting", "sorting"),
("spikes", "npy"),
("templates", "zarr_templates"),
]
_result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")]
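
The new ("spikes", "npy") entry registers the raw spike vector as a run result that presumably ends up persisted as a plain .npy file. A minimal, self-contained sketch of that kind of storage, assuming a structured array whose field names are inferred from the fields used in run() above (sample_index here is an additional assumption):

import numpy as np

# Hypothetical "spikes" structured array; the dtype fields are an assumption
# based on the fields referenced in run() (cluster_index, segment_index).
spikes_dtype = [("sample_index", "int64"), ("cluster_index", "int64"), ("segment_index", "int64")]
spikes = np.zeros(4, dtype=spikes_dtype)
spikes["sample_index"] = [120, 340, 560, 780]
spikes["cluster_index"] = [0, 1, 0, 2]

# An "npy" run key boils down to saving/loading the array with NumPy.
np.save("spikes.npy", spikes)
loaded = np.load("spikes.npy")
assert np.array_equal(loaded["cluster_index"], spikes["cluster_index"])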
@@ -71,6 +72,11 @@ def plot_performances_vs_snr(self, **kwargs):

return plot_performances_vs_snr(self, **kwargs)

def plot_performances_comparison(self, **kwargs):
from .benchmark_plot_tools import plot_performances_comparison

return plot_performances_comparison(self, **kwargs)
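
The added method mirrors plot_performances_vs_snr just above: the study class is a thin wrapper that forwards all keyword arguments to a module-level helper in benchmark_plot_tools. A self-contained sketch of that delegation pattern with illustrative names (MatchingStudy, the helper body, and the case keys below are stand-ins, not the actual spikeinterface classes):

def plot_performances_comparison(study, **kwargs):
    # stand-in for the real helper defined in benchmark_plot_tools.py
    print(f"comparing {len(study.cases)} cases with options {kwargs}")


class MatchingStudy:
    def __init__(self, cases):
        self.cases = cases

    def plot_performances_comparison(self, **kwargs):
        # the real method does a local "from .benchmark_plot_tools import ..."
        # so the plotting helper is only imported when a plot is requested
        return plot_performances_comparison(self, **kwargs)


MatchingStudy(cases={"tdc": {}, "circus": {}}).plot_performances_comparison(figsize=(10, 10))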

def plot_collisions(self, case_keys=None, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())
@@ -90,70 +96,6 @@ def plot_collisions(self, case_keys=None, figsize=None):

return fig

def plot_comparison_matching(
self,
case_keys=None,
performance_names=["accuracy", "recall", "precision"],
colors=["g", "b", "r"],
ylim=(-0.1, 1.1),
figsize=None,
):

if case_keys is None:
case_keys = list(self.cases.keys())

num_methods = len(case_keys)
import pylab as plt

fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10))
for i, key1 in enumerate(case_keys):
for j, key2 in enumerate(case_keys):
if len(axs.shape) > 1:
ax = axs[i, j]
else:
ax = axs[j]
comp1 = self.get_result(key1)["gt_comparison"]
comp2 = self.get_result(key2)["gt_comparison"]
if i <= j:
for performance, color in zip(performance_names, colors):
perf1 = comp1.get_performance()[performance]
perf2 = comp2.get_performance()[performance]
ax.plot(perf2, perf1, ".", label=performance, color=color)

ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
ax.set_ylim(ylim)
ax.set_xlim(ylim)
ax.spines[["right", "top"]].set_visible(False)
ax.set_aspect("equal")

label1 = self.cases[key1]["label"]
label2 = self.cases[key2]["label"]
if j == i:
ax.set_ylabel(f"{label1}")
else:
ax.set_yticks([])
if i == j:
ax.set_xlabel(f"{label2}")
else:
ax.set_xticks([])
if i == num_methods - 1 and j == num_methods - 1:
patches = []
import matplotlib.patches as mpatches

for color, name in zip(colors, performance_names):
patches.append(mpatches.Patch(color=color, label=name))
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)
else:
ax.spines["bottom"].set_visible(False)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
plt.tight_layout(h_pad=0, w_pad=0)

return fig

def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
import pandas as pd

@@ -196,6 +138,7 @@ def plot_unit_counts(self, case_keys=None, figsize=None):
plot_study_unit_counts(self, case_keys, figsize=figsize)

def plot_unit_losses(self, before, after, metric=["precision"], figsize=None):
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False)
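
The local matplotlib import added here makes plt available inside plot_unit_losses, and the squeeze=False in the call above keeps axs a 2-D array even when only one metric is requested. A short, illustrative sketch of that matplotlib behaviour:

import matplotlib.pyplot as plt

# With squeeze=False, plt.subplots always returns a 2-D array of Axes,
# so axs[row, col] indexing keeps working even for a single row and column.
fig, axs = plt.subplots(ncols=1, nrows=1, squeeze=False)
print(axs.shape)  # (1, 1)
axs[0, 0].plot([0, 1], [0, 1])
plt.close(fig)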

64 changes: 63 additions & 1 deletion src/spikeinterface/benchmark/benchmark_plot_tools.py
@@ -235,9 +235,71 @@ def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accu
ax.scatter(x, y, marker=".", label=label)
ax.set_title(k)

ax.set_ylim(0, 1.05)
ax.set_ylim(-0.05, 1.05)

if count == 2:
ax.legend()

return fig


def plot_performances_comparison(
study,
case_keys=None,
figsize=None,
metrics=["accuracy", "recall", "precision"],
colors=["g", "b", "r"],
ylim=(-0.1, 1.1),
):
import matplotlib.pyplot as plt

if case_keys is None:
case_keys = list(study.cases.keys())

num_methods = len(case_keys)
assert num_methods >= 2, "plot_performances_comparison needs at least 2 cases!"

fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(10, 10), squeeze=False)
for i, key1 in enumerate(case_keys):
for j, key2 in enumerate(case_keys):

if i < j:
ax = axs[i, j - 1]

comp1 = study.get_result(key1)["gt_comparison"]
comp2 = study.get_result(key2)["gt_comparison"]

for performance, color in zip(metrics, colors):
perf1 = comp1.get_performance()[performance]
perf2 = comp2.get_performance()[performance]
ax.scatter(perf2, perf1, marker=".", label=performance, color=color)

ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
ax.set_ylim(ylim)
ax.set_xlim(ylim)
ax.spines[["right", "top"]].set_visible(False)
ax.set_aspect("equal")

label1 = study.cases[key1]["label"]
label2 = study.cases[key2]["label"]

if i == j - 1:
ax.set_xlabel(label2)
ax.set_ylabel(label1)

else:
if j >= 1 and i < num_methods - 1:
ax = axs[i, j - 1]
ax.spines[["right", "top", "left", "bottom"]].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])

ax = axs[num_methods - 2, 0]
patches = []
from matplotlib.patches import Patch

for color, name in zip(colors, metrics):
patches.append(Patch(color=color, label=name))
ax.legend(handles=patches)
fig.tight_layout()
return fig
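
For reference, the pairwise grid produced by the new helper can be reproduced in isolation. The sketch below substitutes synthetic per-unit accuracies for real gt_comparison results; the case names and the single metric shown are illustrative:

import numpy as np
import matplotlib.pyplot as plt

# Self-contained sketch of the (n-1) x (n-1) pairwise grid built by
# plot_performances_comparison, using synthetic accuracies instead of
# real gt_comparison results. Only the upper triangle (i < j) is drawn.
rng = np.random.default_rng(0)
case_keys = ["method_A", "method_B", "method_C"]
accuracy = {key: rng.uniform(0.5, 1.0, size=50) for key in case_keys}

num_methods = len(case_keys)
fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(8, 8), squeeze=False)
for i, key1 in enumerate(case_keys):
    for j, key2 in enumerate(case_keys):
        if i < j:
            # case j on x, case i on y, plus the identity line
            ax = axs[i, j - 1]
            ax.scatter(accuracy[key2], accuracy[key1], marker=".")
            ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
            ax.set_xlim(-0.1, 1.1)
            ax.set_ylim(-0.1, 1.1)
            ax.set_aspect("equal")
            if i == j - 1:
                ax.set_xlabel(key2)
                ax.set_ylabel(key1)
        elif j >= 1 and i < num_methods - 1:
            # unused lower-triangle panels are blanked out
            axs[i, j - 1].axis("off")
fig.tight_layout()
plt.show()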