Skip to content

Commit

Permalink
Merge pull request #3493 from chrishalcrow/report_without_waveforms
Browse files Browse the repository at this point in the history
Allow `export_report` to run without waveforms
  • Loading branch information
alejoe91 authored Nov 4, 2024
2 parents 5f81566 + d337001 commit 0baf7e0
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 53 deletions.
8 changes: 4 additions & 4 deletions src/spikeinterface/exporters/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ def export_report(
**job_kwargs,
):
"""
Exports a SI spike sorting report. The report includes summary figures of the spike sorting output
(e.g. amplitude distributions, unit localization and depth VS amplitude) as well as unit-specific reports,
that include waveforms, templates, template maps, ISI distributions, and more.
Exports a SI spike sorting report. The report includes summary figures of the spike sorting output.
What is plotted depends on what has been calculated. Unit locations and unit waveforms are always included.
Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms`,
and `spike_amplitudes` have been computed for the given `sorting_analyzer`.
Parameters
----------
Expand Down
8 changes: 7 additions & 1 deletion src/spikeinterface/widgets/autocorrelograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@ class AutoCorrelogramsWidget(CrossCorrelogramsWidget):
# the doc is copied form CrossCorrelogramsWidget

def __init__(self, *args, **kargs):
CrossCorrelogramsWidget.__init__(self, *args, **kargs)
_ = kargs.pop("min_similarity_for_correlograms", 0.0)
CrossCorrelogramsWidget.__init__(
self,
*args,
**kargs,
min_similarity_for_correlograms=None,
)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/widgets/crosscorrelograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class CrossCorrelogramsWidget(BaseWidget):
List of unit ids
min_similarity_for_correlograms : float, default: 0.2
For sortingview backend. Threshold for computing pair-wise cross-correlograms.
If template similarity between two units is below this threshold, the cross-correlogram is not displayed
If template similarity between two units is below this threshold, the cross-correlogram is not displayed.
For auto-correlograms plot, this is automatically set to None.
window_ms : float, default: 100.0
Window for CCGs in ms. If correlograms are already computed (e.g. with SortingAnalyzer),
this argument is ignored
Expand Down
102 changes: 55 additions & 47 deletions src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,82 +107,90 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
# and use custum grid spec
fig = self.figure
nrows = 2
ncols = 3
if sorting_analyzer.has_extension("correlograms") or sorting_analyzer.has_extension("spike_amplitudes"):
ncols = 2
if sorting_analyzer.has_extension("correlograms"):
ncols += 1
if sorting_analyzer.has_extension("waveforms"):
ncols += 1
if sorting_analyzer.has_extension("spike_amplitudes"):
nrows += 1
gs = fig.add_gridspec(nrows, ncols)
col_counter = 0

if sorting_analyzer.has_extension("unit_locations"):
ax1 = fig.add_subplot(gs[:2, 0])
# UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1)
w = UnitLocationsWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
plot_legend=False,
backend="matplotlib",
ax=ax1,
**unitlocationswidget_kwargs,
)

unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
unit_location = unit_locations[unit_id]
x, y = unit_location[0], unit_location[1]
ax1.set_xlim(x - 80, x + 80)
ax1.set_ylim(y - 250, y + 250)
ax1.set_xticks([])
ax1.set_xlabel(None)
ax1.set_ylabel(None)

ax2 = fig.add_subplot(gs[:2, 1])
w = UnitWaveformsWidget(
# Unit locations and unit waveform plots are always generated
ax_unit_locations = fig.add_subplot(gs[:2, col_counter])
_ = UnitLocationsWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
plot_legend=False,
backend="matplotlib",
ax=ax_unit_locations,
**unitlocationswidget_kwargs,
)
col_counter += 1

unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
unit_location = unit_locations[unit_id]
x, y = unit_location[0], unit_location[1]
ax_unit_locations.set_xlim(x - 80, x + 80)
ax_unit_locations.set_ylim(y - 250, y + 250)
ax_unit_locations.set_xticks([])
ax_unit_locations.set_xlabel(None)
ax_unit_locations.set_ylabel(None)

ax_unit_waveforms = fig.add_subplot(gs[:2, col_counter])
_ = UnitWaveformsWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
plot_templates=True,
plot_waveforms=sorting_analyzer.has_extension("waveforms"),
same_axis=True,
plot_legend=False,
sparsity=sparsity,
backend="matplotlib",
ax=ax2,
ax=ax_unit_waveforms,
**unitwaveformswidget_kwargs,
)
col_counter += 1

ax2.set_title(None)
ax_unit_waveforms.set_title(None)

ax3 = fig.add_subplot(gs[:2, 2])
UnitWaveformDensityMapWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
use_max_channel=True,
same_axis=False,
backend="matplotlib",
ax=ax3,
**unitwaveformdensitymapwidget_kwargs,
)
ax3.set_ylabel(None)
if sorting_analyzer.has_extension("waveforms"):
ax_waveform_density = fig.add_subplot(gs[:2, col_counter])
UnitWaveformDensityMapWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
use_max_channel=True,
same_axis=False,
backend="matplotlib",
ax=ax_waveform_density,
**unitwaveformdensitymapwidget_kwargs,
)
col_counter += 1
ax_waveform_density.set_ylabel(None)

if sorting_analyzer.has_extension("correlograms"):
ax4 = fig.add_subplot(gs[:2, 3])
ax_correlograms = fig.add_subplot(gs[:2, col_counter])
AutoCorrelogramsWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
backend="matplotlib",
ax=ax4,
ax=ax_correlograms,
**autocorrelogramswidget_kwargs,
)
col_counter += 1

ax4.set_title(None)
ax4.set_yticks([])
ax_correlograms.set_title(None)
ax_correlograms.set_yticks([])

if sorting_analyzer.has_extension("spike_amplitudes"):
ax5 = fig.add_subplot(gs[2, :3])
ax6 = fig.add_subplot(gs[2, 3])
axes = np.array([ax5, ax6])
ax_spike_amps = fig.add_subplot(gs[2, : col_counter - 1])
ax_amps_distribution = fig.add_subplot(gs[2, col_counter - 1])
axes = np.array([ax_spike_amps, ax_amps_distribution])
AmplitudesWidget(
sorting_analyzer,
unit_ids=[unit_id],
Expand Down

0 comments on commit 0baf7e0

Please sign in to comment.