Skip to content

Commit

Permalink
Some more cosmetics.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Oct 25, 2023
1 parent 79ac8d1 commit 5e7a57f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
12 changes: 6 additions & 6 deletions src/spikeinterface/widgets/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,17 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white")

# Major ticks
self.ax.set_xticks(np.arange(0, N2))
self.ax.set_yticks(np.arange(0, N1))
self.ax.xaxis.tick_bottom()

# Labels for major ticks
if dp.unit_ticks:
self.ax.set_yticklabels(scores.index, fontsize=12)
self.ax.set_xticklabels(scores.columns, fontsize=12)
self.ax.set_xticks(np.arange(0, N2))
self.ax.set_yticks(np.arange(0, N1))
self.ax.set_yticklabels(scores.index)
self.ax.set_xticklabels(scores.columns)

self.ax.set_xlabel(comp.name_list[1], fontsize=20)
self.ax.set_ylabel(comp.name_list[0], fontsize=20)
self.ax.set_xlabel(comp.name_list[1])
self.ax.set_ylabel(comp.name_list[0])

self.ax.set_xlim(-0.5, N2 - 0.5)
self.ax.set_ylim(
Expand Down
26 changes: 19 additions & 7 deletions src/spikeinterface/widgets/gtstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
ax.plot(val, label=label)
ax.set_title(performance_name)
if count == 0:
ax.legend()
ax.legend(loc='upper right')

elif dp.mode == "snr":

Expand All @@ -203,13 +203,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
x = study.get_metrics(key).loc[:, metric_name].values
y = perfs.xs(key).loc[:, performance_name].values
label = study.cases[key]["label"]
ax.scatter(x, y, label=label)
ax.scatter(x, y, s=10, label=label)
max_metric = max(max_metric, np.max(x))
ax.set_title(performance_name)
ax.set_xlim(0, max_metric * 1.05)
ax.set_ylim(0, 1.05)
if count == 0:
ax.legend()
ax.legend(loc='lower right')


elif dp.mode == "swarm":
Expand Down Expand Up @@ -245,7 +245,7 @@ class StudyAgreementMatrix(BaseWidget):
def __init__(
self,
study,
ordered=True, count_text=True,
ordered=True,
case_keys=None,
backend=None,
**backend_kwargs,
Expand All @@ -257,7 +257,6 @@ def __init__(
study=study,
case_keys=case_keys,
ordered=ordered,
count_text=count_text,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
Expand All @@ -276,9 +275,22 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
for count, key in enumerate(dp.case_keys):
ax = self.axes.flatten()[count]
comp = study.comparisons[key]
AgreementMatrixWidget(comp, ordered=dp.ordered, count_text=dp.count_text, backend='matplotlib', ax=ax)
unit_ticks = len(comp.sorting1.unit_ids) <= 16
count_text = len(comp.sorting1.unit_ids) <= 16


AgreementMatrixWidget(comp, ordered=dp.ordered, count_text=count_text, unit_ticks=unit_ticks, backend='matplotlib', ax=ax)
label = study.cases[key]["label"]
ax.set_title(label)
ax.set_xlabel(label)

if count > 0:
ax.set_ylabel(None)
ax.set_yticks([])
ax.set_xticks([])

# ax0 = self.axes.flatten()[0]
# for ax in self.axes.flatten()[1:]:
# ax.sharey(ax0)


class StudySummary(BaseWidget):
Expand Down

0 comments on commit 5e7a57f

Please sign in to comment.