Some improvement in Study and related widgets #2128

Merged: 17 commits, Oct 26, 2023
56 changes: 30 additions & 26 deletions src/spikeinterface/comparison/groundtruthstudy.py
@@ -88,6 +88,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
(study_folder / "sortings").mkdir()
(study_folder / "sortings" / "run_logs").mkdir()
(study_folder / "metrics").mkdir()
(study_folder / "comparisons").mkdir()

for key, (rec, gt_sorting) in datasets.items():
assert "/" not in key, "'/' cannot be in the key name!"
@@ -127,16 +128,17 @@ def scan_folder(self):
with open(self.folder / "cases.pickle", "rb") as f:
self.cases = pickle.load(f)

self.sortings = {k: None for k in self.cases}
self.comparisons = {k: None for k in self.cases}

self.sortings = {}
for key in self.cases:
sorting_folder = self.folder / "sortings" / self.key_to_str(key)
if sorting_folder.exists():
sorting = load_extractor(sorting_folder)
else:
sorting = None
self.sortings[key] = sorting
self.sortings[key] = load_extractor(sorting_folder)

comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
if comparison_file.exists():
with open(comparison_file, mode="rb") as f:
self.comparisons[key] = pickle.load(f)

def __repr__(self):
t = f"{self.__class__.__name__} {self.folder.stem} \n"
@@ -155,6 +157,16 @@ def key_to_str(self, key):
else:
raise ValueError("Keys for cases must str or tuple")

def remove_sorting(self, key):
sorting_folder = self.folder / "sortings" / self.key_to_str(key)
log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
comparison_file = self.folder / "comparisons" / self.key_to_str(key)
if sorting_folder.exists():
shutil.rmtree(sorting_folder)
for f in (log_file, comparison_file):
if f.exists():
f.unlink()

def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False):
if case_keys is None:
case_keys = self.cases.keys()
@@ -178,12 +190,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True
self.copy_sortings(case_keys=[key])
continue

if sorting_exists:
# delete older sorting + log before running sorters
shutil.rmtree(sorting_folder)
log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
if log_file.exists():
log_file.unlink()
self.remove_sorting(key)

if sorter_folder_exists:
shutil.rmtree(sorter_folder)
@@ -228,10 +235,7 @@ def copy_sortings(self, case_keys=None, force=True):
if sorting is not None:
if sorting_folder.exists():
if force:
# delete folder + log
shutil.rmtree(sorting_folder)
if log_file.exists():
log_file.unlink()
self.remove_sorting(key)
else:
continue

@@ -255,6 +259,10 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison
comp = comparison_class(gt_sorting, sorting, **kwargs)
self.comparisons[key] = comp

comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
with open(comparison_file, mode="wb") as f:
pickle.dump(comp, f)

def get_run_times(self, case_keys=None):
import pandas as pd

@@ -288,20 +296,16 @@ def extract_waveforms_gt(self, case_keys=None, **extract_kwargs):
recording, gt_sorting = self.datasets[dataset_key]
we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs)

def get_waveform_extractor(self, key):
# some recordings are not dumpable to json and the waveform extractor needs it!
# so we load it without the recording and set it afterwards
# this should be fixed in PR 2027 so remove this after
def get_waveform_extractor(self, case_key=None, dataset_key=None):
if case_key is not None:
dataset_key = self.cases[case_key]["dataset"]

dataset_key = self.cases[key]["dataset"]
wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key)
we = load_waveforms(wf_folder, with_recording=False)
recording, _ = self.datasets[dataset_key]
we.set_recording(recording)
we = load_waveforms(wf_folder, with_recording=True)
return we

def get_templates(self, key, mode="average"):
we = self.get_waveform_extractor(key)
we = self.get_waveform_extractor(case_key=key)
templates = we.get_all_templates(mode=mode)
return templates

@@ -366,7 +370,7 @@ def get_performance_by_unit(self, case_keys=None):
perf_by_unit.append(perf)

perf_by_unit = pd.concat(perf_by_unit)
perf_by_unit = perf_by_unit.set_index(self.levels)
perf_by_unit = perf_by_unit.set_index(self.levels).sort_index()
return perf_by_unit

def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
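For context, a hedged end-to-end sketch of the reworked GroundTruthStudy workflow touched by the hunks above. Only the method names and signatures (create, run_sorters, run_comparisons, extract_waveforms_gt, get_waveform_extractor, get_performance_by_unit) come from this diff; the toy dataset, the case-dict fields other than "dataset", and the sorter name are illustrative assumptions.

    from spikeinterface.extractors import toy_example
    from spikeinterface.comparison import GroundTruthStudy

    # Toy ground-truth dataset; any (recording, gt_sorting) pair works here.
    recording, gt_sorting = toy_example(duration=30, num_units=10, seed=0)

    datasets = {"toy": (recording, gt_sorting)}
    # "dataset" is the only case field used in this diff; "label" and
    # "run_sorter_params" follow the usual study layout and are assumptions here.
    cases = {
        ("toy", "tdc2"): {
            "label": "tridesclous2 on toy",
            "dataset": "toy",
            "run_sorter_params": {"sorter_name": "tridesclous2"},
        },
    }

    study = GroundTruthStudy.create("my_study_folder", datasets=datasets, cases=cases,
                                    levels=["dataset", "sorter"])

    study.run_sorters(engine="loop", keep=True, verbose=True)
    study.run_comparisons()          # comparisons are now also pickled under "comparisons/"
    study.extract_waveforms_gt()

    # The waveform extractor can now be requested by case key or directly by dataset key.
    we = study.get_waveform_extractor(case_key=("toy", "tdc2"))
    we = study.get_waveform_extractor(dataset_key="toy")

    perf = study.get_performance_by_unit()   # index is now sorted by the study levels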
18 changes: 3 additions & 15 deletions src/spikeinterface/sortingcomponents/clustering/merge.py
@@ -20,6 +20,9 @@
from .tools import aggregate_sparse_features, FeaturesLoader, compute_template_from_sparse


DEBUG = False


def merge_clusters(
peaks,
peak_labels,
@@ -81,7 +84,6 @@ def merge_clusters(
**job_kwargs,
)

DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

@@ -224,17 +226,13 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"
else:
raise ValueError

# DEBUG = True
DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

fig = plt.figure()
nx.draw_networkx(sub_graph)
plt.show()

# DEBUG = True
DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

@@ -551,15 +549,7 @@ def merge(
else:
final_shift = 0

# DEBUG = True
DEBUG = False

# if DEBUG and is_merge:
# if DEBUG and (overlap > 0.1 and overlap <0.3):
if DEBUG:
# if DEBUG and not is_merge:
# if DEBUG and (overlap > 0.05 and overlap <0.25):
# if label0 == 49 and label1== 65:
import matplotlib.pyplot as plt

flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1)
@@ -674,8 +664,6 @@ def merge(
final_shift = 0
merge_value = np.nan

# DEBUG = False
DEBUG = True
if DEBUG and normed_diff < 0.2:
# if DEBUG:

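The merge.py hunks above consolidate the scattered per-function DEBUG = False / DEBUG = True re-assignments into a single module-level flag. A minimal, self-contained sketch of that pattern (the function and data below are illustrative only, not spikeinterface code):

    import numpy as np

    DEBUG = False  # flip to True locally to get diagnostic plots


    def merge_step(wfs0, wfs1):
        """Toy merge step: compare two waveform stacks and optionally plot them."""
        diff = np.mean(np.abs(wfs0.mean(axis=0) - wfs1.mean(axis=0)))

        if DEBUG:
            # plotting is only imported and used when debugging, as in the PR
            import matplotlib.pyplot as plt

            fig, ax = plt.subplots()
            ax.plot(wfs0.mean(axis=0).T.flatten(), label="cluster 0")
            ax.plot(wfs1.mean(axis=0).T.flatten(), label="cluster 1")
            ax.set_title(f"mean abs diff = {diff:.3f}")
            ax.legend()
            plt.show()

        return diff


    rng = np.random.default_rng(0)
    merge_step(rng.normal(size=(50, 30, 4)), rng.normal(size=(50, 30, 4)))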
87 changes: 0 additions & 87 deletions src/spikeinterface/widgets/agreement_matrix.py

This file was deleted.

@@ -1,8 +1,6 @@
import numpy as np
from warnings import warn

from .base import BaseWidget, to_attr
from .utils import get_unit_colors


class ConfusionMatrixWidget(BaseWidget):
@@ -77,3 +75,85 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
N1 + 0.5,
-0.5,
)


class AgreementMatrixWidget(BaseWidget):
"""
Plots sorting comparison agreement matrix.

Parameters
----------
sorting_comparison: GroundTruthComparison or SymmetricSortingComparison
The sorting comparison object.
Can optionally be symmetric if given a SymmetricSortingComparison
ordered: bool, default: True
Order units with best agreement scores.
If True, agreement scores can be seen along a diagonal
count_text: bool, default: True
If True counts are displayed as text
unit_ticks: bool, default: True
If True unit tick labels are displayed

"""

def __init__(
self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs
):
plot_data = dict(
sorting_comparison=sorting_comparison,
ordered=ordered,
count_text=count_text,
unit_ticks=unit_ticks,
)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
from .utils_matplotlib import make_mpl_figure

dp = to_attr(data_plot)

self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

comp = dp.sorting_comparison

if dp.ordered:
scores = comp.get_ordered_agreement_scores()
else:
scores = comp.agreement_scores

N1 = scores.shape[0]
N2 = scores.shape[1]

unit_ids1 = scores.index.values
unit_ids2 = scores.columns.values

# Using matshow here just because it sets the ticks up nicely. imshow is faster.
self.ax.matshow(scores.values, cmap="Greens")

if dp.count_text:
for i, u1 in enumerate(unit_ids1):
u2 = comp.best_match_12[u1]
if u2 != -1:
j = np.where(unit_ids2 == u2)[0][0]

self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white")

# Major ticks
self.ax.xaxis.tick_bottom()

# Labels for major ticks
if dp.unit_ticks:
self.ax.set_xticks(np.arange(0, N2))
self.ax.set_yticks(np.arange(0, N1))
self.ax.set_yticklabels(scores.index)
self.ax.set_xticklabels(scores.columns)

self.ax.set_xlabel(comp.name_list[1])
self.ax.set_ylabel(comp.name_list[0])

self.ax.set_xlim(-0.5, N2 - 0.5)
self.ax.set_ylim(
N1 - 0.5,
-0.5,
)
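A hedged usage sketch of the relocated AgreementMatrixWidget: the widget arguments (ordered, count_text, unit_ticks) come from the class above, while toy_example, compare_sorter_to_ground_truth and the plot_agreement_matrix wrapper come from the broader spikeinterface API and are not part of this diff.

    from spikeinterface.extractors import toy_example
    from spikeinterface.comparison import compare_sorter_to_ground_truth
    from spikeinterface.widgets import plot_agreement_matrix  # functional wrapper usually registered for this widget

    # Compare a toy sorting against itself just to get a comparison object with a clean diagonal.
    _, sorting = toy_example(duration=20, num_units=8, seed=0)
    comp = compare_sorter_to_ground_truth(sorting, sorting)

    # ordered=True sorts units so the best agreement scores lie on the diagonal;
    # count_text=True prints each matched score inside its cell.
    plot_agreement_matrix(comp, ordered=True, count_text=True, unit_ticks=True, backend="matplotlib")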