Skip to content

Commit

Permalink
Respond to Joe; add template_similarity check
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Oct 29, 2024
1 parent b62e02b commit 7c953c2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 45 deletions.
8 changes: 4 additions & 4 deletions src/spikeinterface/exporters/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ def export_report(
**job_kwargs,
):
"""
Exports a SI spike sorting report. The report includes summary figures of the spike sorting output
(e.g. amplitude distributions, unit localization and depth VS amplitude) as well as unit-specific reports,
that include waveforms, templates, template maps, ISI distributions, and more.
Exports a SI spike sorting report. The report includes summary figures of the spike sorting output.
What is plotted depends on what has been calculated. Unit locations and unit waveforms are always included.
Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms`
and 'template_similarity', and `spike_amplitudes` have been computed for the given `sorting_analyzer`.
Parameters
----------
Expand Down
82 changes: 41 additions & 41 deletions src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
fig = self.figure
nrows = 2
ncols = 2
if sorting_analyzer.has_extension("correlograms") or sorting_analyzer.has_extension("spike_amplitudes"):
if sorting_analyzer.has_extension("correlograms") and sorting_analyzer.has_extension("template_similarity"):
ncols += 1
if sorting_analyzer.has_extension("waveforms"):
ncols += 1
Expand All @@ -117,31 +117,30 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
gs = fig.add_gridspec(nrows, ncols)
col_counter = 0

if sorting_analyzer.has_extension("unit_locations"):
ax1 = fig.add_subplot(gs[:2, col_counter])
# UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1)
w = UnitLocationsWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
plot_legend=False,
backend="matplotlib",
ax=ax1,
**unitlocationswidget_kwargs,
)
col_counter = col_counter + 1

unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
unit_location = unit_locations[unit_id]
x, y = unit_location[0], unit_location[1]
ax1.set_xlim(x - 80, x + 80)
ax1.set_ylim(y - 250, y + 250)
ax1.set_xticks([])
ax1.set_xlabel(None)
ax1.set_ylabel(None)

ax2 = fig.add_subplot(gs[:2, col_counter])
w = UnitWaveformsWidget(
# Unit locations and unit waveform plots are always generated
ax_unit_locations = fig.add_subplot(gs[:2, col_counter])
_ = UnitLocationsWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
plot_legend=False,
backend="matplotlib",
ax=ax_unit_locations,
**unitlocationswidget_kwargs,
)
col_counter += 1

unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
unit_location = unit_locations[unit_id]
x, y = unit_location[0], unit_location[1]
ax_unit_locations.set_xlim(x - 80, x + 80)
ax_unit_locations.set_ylim(y - 250, y + 250)
ax_unit_locations.set_xticks([])
ax_unit_locations.set_xlabel(None)
ax_unit_locations.set_ylabel(None)

ax_unit_waveforms = fig.add_subplot(gs[:2, col_counter])
_ = UnitWaveformsWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
Expand All @@ -151,46 +150,47 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
plot_legend=False,
sparsity=sparsity,
backend="matplotlib",
ax=ax2,
ax=ax_unit_waveforms,
**unitwaveformswidget_kwargs,
)
col_counter = col_counter + 1
col_counter += 1

ax2.set_title(None)
ax_unit_waveforms.set_title(None)

if sorting_analyzer.has_extension("waveforms"):
ax3 = fig.add_subplot(gs[:2, col_counter])
ax_waveform_density = fig.add_subplot(gs[:2, col_counter])
UnitWaveformDensityMapWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
use_max_channel=True,
same_axis=False,
backend="matplotlib",
ax=ax3,
ax=ax_waveform_density,
**unitwaveformdensitymapwidget_kwargs,
)
ax3.set_ylabel(None)
col_counter = col_counter + 1
col_counter += 1
ax_waveform_density.set_ylabel(None)

if sorting_analyzer.has_extension("correlograms"):
ax4 = fig.add_subplot(gs[:2, col_counter])
if sorting_analyzer.has_extension("correlograms") and sorting_analyzer.has_extension("template_similarity"):
ax_correlograms = fig.add_subplot(gs[:2, col_counter])
AutoCorrelogramsWidget(
sorting_analyzer,
unit_ids=[unit_id],
unit_colors=unit_colors,
backend="matplotlib",
ax=ax4,
ax=ax_correlograms,
**autocorrelogramswidget_kwargs,
)
col_counter += 1

ax4.set_title(None)
ax4.set_yticks([])
ax_correlograms.set_title(None)
ax_correlograms.set_yticks([])

if sorting_analyzer.has_extension("spike_amplitudes"):
ax5 = fig.add_subplot(gs[2, :col_counter])
ax6 = fig.add_subplot(gs[2, col_counter])
axes = np.array([ax5, ax6])
ax_spike_amps = fig.add_subplot(gs[2, : col_counter - 1])
ax_amps_distribution = fig.add_subplot(gs[2, col_counter - 1])
axes = np.array([ax_spike_amps, ax_amps_distribution])
AmplitudesWidget(
sorting_analyzer,
unit_ids=[unit_id],
Expand Down

0 comments on commit 7c953c2

Please sign in to comment.