Benchmarks components: plotting utils (#2959)
yger authored Jul 5, 2024
1 parent 9703af1 commit f8a4331
Showing 7 changed files with 120 additions and 67 deletions.
20 changes: 17 additions & 3 deletions src/spikeinterface/comparison/groundtruthstudy.py
@@ -9,9 +9,6 @@
import numpy as np

from spikeinterface.core import load_extractor, create_sorting_analyzer, load_sorting_analyzer
from spikeinterface.core.core_tools import SIJsonEncoder
from spikeinterface.core.job_tools import split_job_kwargs

from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder

from spikeinterface.qualitymetrics import compute_quality_metrics
@@ -54,6 +51,7 @@ def __init__(self, study_folder):
self.cases = {}
self.sortings = {}
self.comparisons = {}
self.colors = None

self.scan_folder()

@@ -175,6 +173,22 @@ def remove_sorting(self, key):
if f.exists():
f.unlink()

def set_colors(self, colors=None, map_name="tab20"):
from spikeinterface.widgets import get_some_colors

if colors is None:
case_keys = list(self.cases.keys())
self.colors = get_some_colors(
case_keys, map_name=map_name, color_engine="matplotlib", shuffle=False, margin=0
)
else:
self.colors = colors

def get_colors(self):
if self.colors is None:
self.set_colors()
return self.colors

def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False):
if case_keys is None:
case_keys = self.cases.keys()
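
For illustration, a minimal sketch of the new color handling (the study folder path is hypothetical; set_colors and get_colors are the methods added above):

    from spikeinterface.comparison import GroundTruthStudy

    study = GroundTruthStudy("my_study_folder")  # hypothetical existing study folder
    study.set_colors(map_name="tab10")           # one matplotlib color per case key
    colors = study.get_colors()                  # dict mapping case_key -> color
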
@@ -185,11 +185,11 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)):
case_keys = list(self.cases.keys())
import pylab as plt

fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize)
fig, axes = plt.subplots(ncols=1, nrows=3, figsize=figsize)

for count, k in enumerate(("accuracy", "recall", "precision")):

ax = axs[count]
ax = axes[count]
for key in case_keys:
label = self.cases[key]["label"]

@@ -211,7 +211,7 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):
case_keys = list(self.cases.keys())
import pylab as plt

fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)

for count, key in enumerate(case_keys):

@@ -234,21 +234,25 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):
else:
distances = sklearn.metrics.pairwise_distances(a, b, metric)

im = axs[0, count].imshow(distances, aspect="auto")
axs[0, count].set_title(metric)
fig.colorbar(im, ax=axs[0, count])
im = axes[0, count].imshow(distances, aspect="auto")
axes[0, count].set_title(metric)
fig.colorbar(im, ax=axes[0, count])
label = self.cases[key]["label"]
axs[0, count].set_title(label)
axes[0, count].set_title(label)

return fig

def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):
def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5), axes=None):

if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
if axes is None:
fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
axes = axes.flatten()
else:
fig = None

for count, key in enumerate(case_keys):

@@ -287,13 +291,13 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5), axes=None):
elif metric == "agreement":
for found, real in zip(matched_ids2[mask], unit_ids1[mask]):
to_plot += [scores.at[real, found]]
axs[0, count].plot(snr_matched, to_plot, ".", label="matched")
axs[0, count].plot(snr_missed, np.zeros(len(snr_missed)), ".", c="r", label="missed")
axs[0, count].set_xlabel("snr")
axs[0, count].set_ylabel(metric)
axes[count].plot(snr_matched, to_plot, ".", label="matched")
axes[count].plot(snr_missed, np.zeros(len(snr_missed)), ".", c="r", label="missed")
axes[count].set_xlabel("snr")
axes[count].set_ylabel(metric)
label = self.cases[key]["label"]
axs[0, count].set_title(label)
axs[0, count].legend()
axes[count].set_title(label)
axes[count].legend()

return fig
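
Since plot_metrics_vs_snr can now draw into caller-supplied axes (it returns fig=None in that case), here is a sketch of reusing it inside a larger figure; `study` is assumed to be an instance of this study class with results loaded:

    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(ncols=len(study.cases), nrows=1, figsize=(15, 5), squeeze=False)
    # pass a flat array: the method indexes the supplied axes as axes[count]
    study.plot_metrics_vs_snr(metric="agreement", axes=axes.flatten())
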

@@ -303,7 +307,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
case_keys = list(self.cases.keys())
import pylab as plt

fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)

for count, key in enumerate(case_keys):

@@ -348,47 +352,61 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
elif metric == "agreement":
for found, real in zip(matched_ids2[mask], unit_ids1[mask]):
to_plot += [scores.at[real, found]]
axs[0, count].scatter(depth_matched, snr_matched, c=to_plot, label="matched")
axs[0, count].scatter(depth_missed, snr_missed, c=np.zeros(len(snr_missed)), label="missed")
axs[0, count].set_xlabel("depth")
axs[0, count].set_ylabel("snr")
elif metric in ["recall", "precision", "accuracy"]:
to_plot = result["gt_comparison"].get_performance()[metric].values
depth_matched = depth
snr_matched = metrics["snr"]

im = axes[0, count].scatter(depth_matched, snr_matched, c=to_plot, label="matched")
im.set_clim(0, 1)
axes[0, count].scatter(depth_missed, snr_missed, c=np.zeros(len(snr_missed)), label="missed")
axes[0, count].set_xlabel("depth")
axes[0, count].set_ylabel("snr")
label = self.cases[key]["label"]
axs[0, count].set_title(label)
axes[0, count].set_title(label)
if count > 0:
axes[0, count].set_ylabel("")
axes[0, count].set_yticks([], [])
# axs[0, count].legend()

fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75])
fig.colorbar(im, cax=cbar_ax, label=metric)

return fig

def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=None):
import pylab as plt
def plot_unit_losses(self, cases_before, cases_after, metric="agreement", figsize=None):

fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize)
fig, axs = plt.subplots(ncols=len(cases_before), nrows=1, figsize=figsize)

for count, k in enumerate(("accuracy", "recall", "precision")):
for count, (case_before, case_after) in enumerate(zip(cases_before, cases_after)):

ax = axs[count]

# label = self.cases[case_after]["label"]

# positions = self.get_result(case_before)["gt_comparison"].sorting1.get_property("gt_unit_locations")

dataset_key = self.cases[case_before]["dataset"]
rec, gt_sorting1 = self.datasets[dataset_key]
_, gt_sorting1 = self.datasets[dataset_key]
positions = gt_sorting1.get_property("gt_unit_locations")

analyzer = self.get_sorting_analyzer(case_before)
metrics_before = analyzer.get_extension("quality_metrics").get_data()
x = metrics_before["snr"].values

y_before = self.get_result(case_before)["gt_comparison"].get_performance()[k].values
y_after = self.get_result(case_after)["gt_comparison"].get_performance()[k].values
if count < 2:
ax.set_xticks([], [])
elif count == 2:
ax.set_xlabel("depth (um)")
im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), marker=".", s=50, cmap="copper")
fig.colorbar(im, ax=ax)
ax.set_title(k)
y_before = self.get_result(case_before)["gt_comparison"].get_performance()[metric].values
y_after = self.get_result(case_after)["gt_comparison"].get_performance()[metric].values
ax.set_ylabel("depth (um)")
ax.set_ylabel("snr")
if count > 0:
ax.set_ylabel("")
ax.set_yticks([], [])
im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm")
im.set_clim(-1, 1)
# fig.colorbar(im, ax=ax)
# ax.set_title(k)

fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75])
cbar = fig.colorbar(im, cax=cbar_ax, label=metric)
# cbar.set_clim(-1, 1)

return fig

def plot_comparison_clustering(
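
For reference, plot_unit_losses above now compares parallel lists of cases rather than a single pair; a hedged sketch (the case keys are hypothetical):

    fig = study.plot_unit_losses(
        cases_before=["sorterA_raw", "sorterB_raw"],        # hypothetical case keys
        cases_after=["sorterA_curated", "sorterB_curated"],
        metric="accuracy",
    )
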
@@ -11,6 +11,9 @@
import numpy as np
from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.sortingcomponents.tools import remove_empty_templates
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.core.sparsity import compute_sparsity


class MatchingBenchmark(Benchmark):
@@ -73,17 +76,15 @@ def plot_agreements(self, case_keys=None, figsize=None):
ax.set_title(self.cases[key]["label"])
plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax)

return fig

def plot_performances_vs_snr(self, case_keys=None, figsize=None):
def plot_performances_vs_snr(self, case_keys=None, figsize=None, metrics=["accuracy", "recall", "precision"]):
if case_keys is None:
case_keys = list(self.cases.keys())

fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize)
fig, axs = plt.subplots(ncols=1, nrows=len(metrics), figsize=figsize, squeeze=False)

for count, k in enumerate(("accuracy", "recall", "precision")):
for count, k in enumerate(metrics):

ax = axs[count]
ax = axs[count, 0]
for key in case_keys:
label = self.cases[key]["label"]

@@ -223,13 +224,13 @@ def plot_unit_counts(self, case_keys=None, figsize=None):

plot_study_unit_counts(self, case_keys, figsize=figsize)

def plot_unit_losses(self, before, after, figsize=None):
def plot_unit_losses(self, before, after, metric=["precision"], figsize=None):

fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize)
fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False)

for count, k in enumerate(("accuracy", "recall", "precision")):
for count, k in enumerate(metric):

ax = axs[count]
ax = axs[count, 0]

label = self.cases[after]["label"]

@@ -241,15 +242,20 @@ def plot_unit_losses(self, before, after, figsize=None):

y_before = self.get_result(before)["gt_comparison"].get_performance()[k].values
y_after = self.get_result(after)["gt_comparison"].get_performance()[k].values
if count < 2:
ax.set_xticks([], [])
elif count == 2:
ax.set_xlabel("depth (um)")
im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), marker=".", s=50, cmap="copper")
fig.colorbar(im, ax=ax)
# if count < 2:
# ax.set_xticks([], [])
# elif count == 2:
ax.set_xlabel("depth (um)")
im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm")
fig.colorbar(im, ax=ax, label=k)
im.set_clim(-1, 1)
ax.set_title(k)
ax.set_ylabel("snr")

# fig.subplots_adjust(right=0.85)
# cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75])
# cbar = fig.colorbar(im, cax=cbar_ax, label=metric)

# if count == 2:
# ax.legend()
return fig
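
A sketch of calling the matching-study variant just above, which now takes a list of metrics and draws one panel per entry (case keys are hypothetical):

    fig = study.plot_unit_losses("before_key", "after_key", metric=["accuracy", "precision"])
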
@@ -196,6 +196,8 @@ def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_thres
abs_threshold = -detect_threshold * noise_levels
ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--")

return fig

def plot_deltas_per_cells(self, case_keys=None, figsize=(15, 5)):

if case_keys is None:
@@ -10,9 +10,11 @@


from spikeinterface.core import SortingAnalyzer
from spikeinterface import load_extractor, split_job_kwargs, create_sorting_analyzer, load_sorting_analyzer

from spikeinterface import load_extractor, create_sorting_analyzer, load_sorting_analyzer
from spikeinterface.widgets import get_some_colors


import pickle

_key_separator = "_-°°-_"
11 changes: 11 additions & 0 deletions src/spikeinterface/sortingcomponents/tools.py
@@ -140,3 +140,14 @@ def remove_empty_templates(templates):
probe=templates.probe,
is_scaled=templates.is_scaled,
)


def sigmoid(x, x0, k, b):
return (1 / (1 + np.exp(-k * (x - x0)))) + b


def fit_sigmoid(xdata, ydata, p0=None):
from scipy.optimize import curve_fit

popt, pcov = curve_fit(sigmoid, xdata, ydata, p0)
return popt
14 changes: 7 additions & 7 deletions src/spikeinterface/widgets/gtstudy.py
@@ -30,9 +30,7 @@ def __init__(
case_keys = list(study.cases.keys())

plot_data = dict(
study=study,
run_times=study.get_run_times(case_keys),
case_keys=case_keys,
study=study, run_times=study.get_run_times(case_keys), case_keys=case_keys, colors=study.get_colors()
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
@@ -48,8 +46,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
for i, key in enumerate(dp.case_keys):
label = dp.study.cases[key]["label"]
rt = dp.run_times.loc[key]
self.ax.bar(i, rt, width=0.8, label=label)

self.ax.bar(i, rt, width=0.8, label=label, facecolor=dp.colors[key])
self.ax.set_ylabel("run time (s)")
self.ax.legend()
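
The run-time bars (and the performance widgets below) now pull their colors from study.get_colors(), so all study plots share one palette. A sketch, assuming plot_study_run_times is the public wrapper for this widget:

    from spikeinterface.widgets import plot_study_run_times  # assumed public wrapper

    study.set_colors(map_name="tab20")  # optional: defaults are generated on first use
    plot_study_run_times(study)
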


@@ -167,6 +165,8 @@ def __init__(
case_keys=case_keys,
)

self.colors = study.get_colors()

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
@@ -192,7 +192,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
label = study.cases[key]["label"]
val = perfs.xs(key).loc[:, performance_name].values
val = np.sort(val)[::-1]
ax.plot(val, label=label)
ax.plot(val, label=label, c=self.colors[key])
ax.set_title(performance_name)
if count == len(dp.performance_names) - 1:
ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8)
@@ -207,7 +207,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
x = study.get_metrics(key).loc[:, metric_name].values
y = perfs.xs(key).loc[:, performance_name].values
label = study.cases[key]["label"]
ax.scatter(x, y, s=10, label=label)
ax.scatter(x, y, s=10, label=label, color=self.colors[key])
max_metric = max(max_metric, np.max(x))
ax.set_title(performance_name)
ax.set_xlim(0, max_metric * 1.05)
