Skip to content

Commit

Permalink
Merge branch 'main' into add_template_generation_function
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored Apr 26, 2024
2 parents c99d57f + 747fc0b commit 2caa539
Show file tree
Hide file tree
Showing 8 changed files with 496 additions and 818 deletions.
544 changes: 160 additions & 384 deletions src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from __future__ import annotations

from spikeinterface.postprocessing import compute_template_similarity
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
from spikeinterface.core.template import Templates
from spikeinterface.core import NumpySorting
from spikeinterface.comparison import CollisionGTComparison, compare_sorter_to_ground_truth
from spikeinterface.widgets import (
plot_agreement_matrix,
plot_comparison_collision_by_similarity,
)

from pathlib import Path
import pylab as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from .benchmark_tools import BenchmarkStudy, Benchmark
from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy
from spikeinterface.core.basesorting import minimum_spike_dtype


Expand Down Expand Up @@ -173,3 +169,74 @@ def plot_comparison_matching(
ax.set_xticks([])
ax.set_yticks([])
plt.tight_layout(h_pad=0, w_pad=0)

def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
    """Build a per-case table of unit counts from the stored ground-truth comparisons.

    Parameters
    ----------
    case_keys : list | None
        Case keys to include; all cases of the study when None.
    well_detected_score : float | None
        Agreement threshold forwarded to ``count_well_detected_units``.
    redundant_score : float | None
        Threshold forwarded to ``count_redundant_units`` / ``count_false_positive_units``.
    overmerged_score : float | None
        Threshold forwarded to ``count_overmerged_units``.

    Returns
    -------
    pd.DataFrame
        Indexed by case key with columns num_gt / num_sorter / num_well_detected,
        plus num_false_positive / num_redundant / num_overmerged / num_bad when
        the comparison has exhaustive ground truth.
    """
    # NOTE: pd comes from the module-level `import pandas as pd`; the previous
    # duplicate function-local import was removed.
    if case_keys is None:
        case_keys = list(self.cases.keys())

    # Single-level case keys are plain strings; tuple keys become a MultiIndex.
    if isinstance(case_keys[0], str):
        index = pd.Index(case_keys, name=self.levels)
    else:
        index = pd.MultiIndex.from_tuples(case_keys, names=self.levels)

    columns = ["num_gt", "num_sorter", "num_well_detected"]
    # The extra error-count columns only make sense with exhaustive ground truth;
    # probe the first case to decide the table shape.
    comp = self.get_result(case_keys[0])["gt_comparison"]
    if comp.exhaustive_gt:
        columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"])
    count_units = pd.DataFrame(index=index, columns=columns, dtype=int)

    for key in case_keys:
        comp = self.get_result(key)["gt_comparison"]
        assert comp is not None, "You need to do study.run_comparisons() first"

        gt_sorting = comp.sorting1
        sorting = comp.sorting2

        count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids())
        count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids())
        count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score)

        if comp.exhaustive_gt:
            count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
            count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score)
            count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score)
            count_units.loc[key, "num_bad"] = comp.count_bad_units()

    return count_units

def plot_unit_counts(self, case_keys=None, figsize=None):
    """Plot per-case unit counts for this study via the spikeinterface widget.

    Thin delegating wrapper: the study itself is passed to the widget, which
    pulls the counts it needs. ``case_keys=None`` means all cases.
    """
    # Imported lazily so the widgets package is only loaded when plotting.
    from spikeinterface.widgets.widget_list import plot_study_unit_counts

    plot_study_unit_counts(self, case_keys, figsize=figsize)

def plot_unit_losses(self, before, after, figsize=None):
    """Scatter per-unit performance change between two cases against depth and SNR.

    One subplot row per metric (accuracy, recall, precision). x is the unit
    depth (second column of ``gt_unit_locations``), y is the SNR from the
    ``before`` case's quality metrics, and the color encodes the performance
    difference ``after - before`` for that metric.

    Parameters
    ----------
    before, after : case key
        The two cases to compare; performance is read from their stored
        ground-truth comparisons.
    figsize : tuple | None
        Forwarded to ``plt.subplots``.
    """
    fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize)

    # Hoisted loop invariants: these were previously recomputed on every
    # metric iteration even though they never change.
    positions = self.get_result(before)["gt_comparison"].sorting1.get_property("gt_unit_locations")
    analyzer = self.get_sorting_analyzer(before)
    metrics_before = analyzer.get_extension("quality_metrics").get_data()
    snr = metrics_before["snr"].values
    perf_before = self.get_result(before)["gt_comparison"].get_performance()
    perf_after = self.get_result(after)["gt_comparison"].get_performance()

    for count, k in enumerate(("accuracy", "recall", "precision")):
        ax = axs[count]

        y_before = perf_before[k].values
        y_after = perf_after[k].values

        # Only the bottom row keeps its x axis labelled.
        if count < 2:
            ax.set_xticks([], [])
        else:
            ax.set_xlabel("depth (um)")
        im = ax.scatter(positions[:, 1], snr, c=(y_after - y_before), marker=".", s=50, cmap="copper")
        fig.colorbar(im, ax=ax)
        ax.set_title(k)
        ax.set_ylabel("snr")
Loading

0 comments on commit 2caa539

Please sign in to comment.