Skip to content

Commit

Permalink
Oups.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Sep 19, 2023
1 parent 4501289 commit 625ff5e
Show file tree
Hide file tree
Showing 4 changed files with 347 additions and 0 deletions.
91 changes: 91 additions & 0 deletions src/spikeinterface/widgets/agreement_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np
from warnings import warn

from .base import BaseWidget, to_attr
from .utils import get_unit_colors



class AgreementMatrixWidget(BaseWidget):
"""
Plot unit depths
Parameters
----------
sorting_comparison: GroundTruthComparison or SymmetricSortingComparison
The sorting comparison object.
Symetric or not.
ordered: bool
Order units with best agreement scores.
This enable to see agreement on a diagonal.
count_text: bool
If True counts are displayed as text
unit_ticks: bool
If True unit tick labels are displayed
"""

def __init__(
self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True,
backend=None, **backend_kwargs
):
plot_data = dict(
sorting_comparison=sorting_comparison,
ordered=ordered,
count_text=count_text,
unit_ticks=unit_ticks,
)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
from .utils_matplotlib import make_mpl_figure

dp = to_attr(data_plot)

self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

comp = dp.sorting_comparison

if dp.ordered:
scores = comp.get_ordered_agreement_scores()
else:
scores = comp.agreement_scores

N1 = scores.shape[0]
N2 = scores.shape[1]

unit_ids1 = scores.index.values
unit_ids2 = scores.columns.values

# Using matshow here just because it sets the ticks up nicely. imshow is faster.
self.ax.matshow(scores.values, cmap="Greens")

if dp.count_text:
for i, u1 in enumerate(unit_ids1):
u2 = comp.best_match_12[u1]
if u2 != -1:
j = np.where(unit_ids2 == u2)[0][0]

self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white")

# Major ticks
self.ax.set_xticks(np.arange(0, N2))
self.ax.set_yticks(np.arange(0, N1))
self.ax.xaxis.tick_bottom()

# Labels for major ticks
if dp.unit_ticks:
self.ax.set_yticklabels(scores.index, fontsize=12)
self.ax.set_xticklabels(scores.columns, fontsize=12)

self.ax.set_xlabel(comp.name_list[1], fontsize=20)
self.ax.set_ylabel(comp.name_list[0], fontsize=20)

self.ax.set_xlim(-0.5, N2 - 0.5)
self.ax.set_ylim(
N1 - 0.5,
-0.5,
)


83 changes: 83 additions & 0 deletions src/spikeinterface/widgets/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import numpy as np
from warnings import warn

from .base import BaseWidget, to_attr
from .utils import get_unit_colors



class ConfusionMatrixWidget(BaseWidget):
"""
Plot unit depths
Parameters
----------
gt_comparison: GroundTruthComparison
The ground truth sorting comparison object
count_text: bool
If True counts are displayed as text
unit_ticks: bool
If True unit tick labels are displayed
"""

def __init__(
self, gt_comparison, count_text=True, unit_ticks=True,
backend=None, **backend_kwargs
):
plot_data = dict(
gt_comparison=gt_comparison,
count_text=count_text,
unit_ticks=unit_ticks,
)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
from .utils_matplotlib import make_mpl_figure

dp = to_attr(data_plot)

self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

comp = dp.gt_comparison

confusion_matrix = comp.get_confusion_matrix()
N1 = confusion_matrix.shape[0] - 1
N2 = confusion_matrix.shape[1] - 1

# Using matshow here just because it sets the ticks up nicely. imshow is faster.
self.ax.matshow(confusion_matrix.values, cmap="Greens")

if dp.count_text:
for (i, j), z in np.ndenumerate(confusion_matrix.values):
if z != 0:
if z > np.max(confusion_matrix.values) / 2.0:
self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="white")
else:
self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="black")

self.ax.axhline(int(N1 - 1) + 0.5, color="black")
self.ax.axvline(int(N2 - 1) + 0.5, color="black")

# Major ticks
self.ax.set_xticks(np.arange(0, N2 + 1))
self.ax.set_yticks(np.arange(0, N1 + 1))
self.ax.xaxis.tick_bottom()

# Labels for major ticks
if dp.unit_ticks:
self.ax.set_yticklabels(confusion_matrix.index, fontsize=12)
self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12)
else:
self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10)
self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10)

self.ax.set_xlabel(comp.name_list[1], fontsize=20)
self.ax.set_ylabel(comp.name_list[0], fontsize=20)

self.ax.set_xlim(-0.5, N2 + 0.5)
self.ax.set_ylim(
N1 + 0.5,
-0.5,
)
78 changes: 78 additions & 0 deletions src/spikeinterface/widgets/probe_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
from warnings import warn

from .base import BaseWidget, to_attr, default_backend_kwargs
from .utils import get_unit_colors



class ProbeMapWidget(BaseWidget):
"""
Plot the probe of a recording.
Parameters
----------
recording: RecordingExtractor
The recording extractor object
channel_ids: list
The channel ids to display
with_channel_ids: bool False default
Add channel ids text on the probe
**plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function
"""

def __init__(
self, recording, channel_ids=None, with_channel_ids=False,
backend=None, **backend_or_plot_probe_kwargs
):

# split backend_or_plot_probe_kwargs
backend_kwargs = dict()
plot_probe_kwargs = dict()
backend = self.check_backend(backend)
for k, v in backend_or_plot_probe_kwargs.items():
if k in default_backend_kwargs[backend]:
backend_kwargs[k] = v
else:
plot_probe_kwargs[k] = v

plot_data = dict(
recording=recording,
channel_ids=channel_ids,
with_channel_ids=with_channel_ids,
plot_probe_kwargs=plot_probe_kwargs,
)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
from .utils_matplotlib import make_mpl_figure
from probeinterface.plotting import get_auto_lims, plot_probe

dp = to_attr(data_plot)

plot_probe_kwargs = dp.plot_probe_kwargs

self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

probegroup = dp.recording.get_probegroup()

xlims, ylims, zlims = get_auto_lims(probegroup.probes[0])
for i, probe in enumerate(probegroup.probes):
xlims2, ylims2, _ = get_auto_lims(probe)
xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1])
ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1])

plot_probe_kwargs["title"] = False
pos = 0
text_on_contact = None
for i, probe in enumerate(probegroup.probes):
n = probe.get_contact_count()
if dp.with_channel_ids:
text_on_contact = dp.recording.channel_ids[pos : pos + n]
pos += n
plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **plot_probe_kwargs)

self.ax.set_xlim(*xlims)
self.ax.set_ylim(*ylims)
95 changes: 95 additions & 0 deletions src/spikeinterface/widgets/rasters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import numpy as np
from warnings import warn

from .base import BaseWidget, to_attr, default_backend_kwargs



class RasterWidget(BaseWidget):
"""
Plots spike train rasters.
Parameters
----------
sorting: SortingExtractor
The sorting extractor object
segment_index: None or int
The segment index.
unit_ids: list
List of unit ids
time_range: list
List with start time and end time
color: matplotlib color
The color to be used
"""

def __init__(
self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k",
backend=None, **backend_kwargs
):


if segment_index is None:
if sorting.get_num_segments() != 1:
raise ValueError("You must provide segment_index=...")
segment_index = 0

if time_range is None:
frame_range = [0, sorting.to_spike_vector()[-1]["sample_index"]]
time_range = [f / sorting.sampling_frequency for f in frame_range]
else:
assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds"
frame_range = [int(t * sorting.sampling_frequency) for t in time_range]

plot_data = dict(
sorting=sorting,
segment_index=segment_index,
unit_ids=unit_ids,
color=color,
frame_range=frame_range,
time_range=time_range,
)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
from .utils_matplotlib import make_mpl_figure

dp = to_attr(data_plot)
sorting = dp.sorting

self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

units_ids = dp.unit_ids
if units_ids is None:
units_ids = sorting.unit_ids

with plt.rc_context({"axes.edgecolor": "gray"}):
for unit_index, unit_id in enumerate(units_ids):
spiketrain = sorting.get_unit_spike_train(
unit_id,
start_frame=dp.frame_range[0],
end_frame=dp.frame_range[1],
segment_index=dp.segment_index,
)
spiketimes = spiketrain / float(sorting.sampling_frequency)
self.ax.plot(
spiketimes,
unit_index * np.ones_like(spiketimes),
marker="|",
mew=1,
markersize=3,
ls="",
color=dp.color,
)
self.ax.set_yticks(np.arange(len(units_ids)))
self.ax.set_yticklabels(units_ids)
self.ax.set_xlim(*dp.time_range)
self.ax.set_xlabel("time (s)")







0 comments on commit 625ff5e

Please sign in to comment.