From c1504f641fef0b7888c6c7d6afe63b7b4ab402c0 Mon Sep 17 00:00:00 2001 From: Jake Swann Date: Thu, 3 Oct 2024 13:39:10 -0400 Subject: [PATCH 01/13] Added vspacing_factor as a param for TracesWidget --- src/spikeinterface/widgets/traces.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 86f2350a85..f5dadc780f 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -52,6 +52,8 @@ class TracesWidget(BaseWidget): If dict, keys should be the same as recording keys scale : float, default: 1 Scale factor for the traces + vspacing_factor : float, default: 1.5 + Vertical spacing between channels as a multiple of maximum channel amplitude with_colorbar : bool, default: True When mode is "map", a colorbar is added tile_size : int, default: 1500 @@ -82,6 +84,7 @@ def __init__( tile_size=1500, seconds_per_row=0.2, scale=1, + vspacing_factor=1.5, with_colorbar=True, add_legend=True, backend=None, @@ -168,7 +171,7 @@ def __init__( traces0 = list_traces[0] mean_channel_std = np.mean(np.std(traces0, axis=0)) max_channel_amp = np.max(np.max(np.abs(traces0), axis=0)) - vspacing = max_channel_amp * 1.5 + vspacing = max_channel_amp * vspacing_factor if rec0.get_channel_groups() is None: color_groups = False From e2fb8cc6985ab51c4d599122ab9643071bd950b9 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 15 Oct 2024 14:31:25 -0600 Subject: [PATCH 02/13] cap python --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a43ab63c8e..e535747428 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.9,<4.0" +requires-python = ">=3.9,<3.14" # Only numpy 2.0 supported on python 3.12 for windows. We need to wait for fix on neo classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", From 5347cba25a67566f0adc46b3dee8e24bfd1a545a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 15 Oct 2024 14:55:52 -0600 Subject: [PATCH 03/13] Update pyproject.toml Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e535747428..403988c980 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.9,<3.14" # Only numpy 2.0 supported on python 3.12 for windows. We need to wait for fix on neo +requires-python = ">=3.9,<3.13" # Only numpy 2.0 supported on python 3.13 for windows. We need to wait for fix on neo classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", From 146d34a58f29165307bd5e00af7e6c4fbb8d2306 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 15 Oct 2024 17:16:02 -0400 Subject: [PATCH 04/13] remove writing text_file --- src/spikeinterface/sorters/basesorter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 3502d27548..28948f81cc 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -145,9 +145,8 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo elif recording.check_serializability("pickle"): recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder) else: - # TODO: deprecate and finally remove this after 0.100 - d = {"warning": "The recording is not serializable to json"} - rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") + raise RuntimeError("This recording is not serializable and so can not be sorted. Consider `recording.save()` to save a " + "compatible binary file.") return output_folder From 5e832c9523a218a032c75d774c71c9188a8114ae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 21:19:30 +0000 Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/basesorter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 28948f81cc..c59fa29c05 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -145,8 +145,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo elif recording.check_serializability("pickle"): recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder) else: - raise RuntimeError("This recording is not serializable and so can not be sorted. Consider `recording.save()` to save a " - "compatible binary file.") + raise RuntimeError( + "This recording is not serializable and so can not be sorted. Consider `recording.save()` to save a " + "compatible binary file." + ) return output_folder From b62e02ba244c40f9eb86e2206bc844e2be339a2d Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 21 Oct 2024 10:39:37 +0100 Subject: [PATCH 06/13] export_report without waveforms --- src/spikeinterface/widgets/unit_summary.py | 44 +++++++++++++--------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 755e60ccbf..8aea6fd690 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -107,15 +107,18 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # and use custum grid spec fig = self.figure nrows = 2 - ncols = 3 + ncols = 2 if sorting_analyzer.has_extension("correlograms") or sorting_analyzer.has_extension("spike_amplitudes"): ncols += 1 + if sorting_analyzer.has_extension("waveforms"): + ncols += 1 if sorting_analyzer.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) + col_counter = 0 if sorting_analyzer.has_extension("unit_locations"): - ax1 = fig.add_subplot(gs[:2, 0]) + ax1 = fig.add_subplot(gs[:2, col_counter]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( sorting_analyzer, @@ -126,6 +129,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax=ax1, **unitlocationswidget_kwargs, ) + col_counter = col_counter + 1 unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") unit_location = unit_locations[unit_id] @@ -136,12 +140,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1.set_xlabel(None) ax1.set_ylabel(None) - ax2 = fig.add_subplot(gs[:2, 1]) + ax2 = fig.add_subplot(gs[:2, col_counter]) w = UnitWaveformsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, plot_templates=True, + plot_waveforms=sorting_analyzer.has_extension("waveforms"), same_axis=True, plot_legend=False, sparsity=sparsity, @@ -149,24 +154,27 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax=ax2, **unitwaveformswidget_kwargs, ) + col_counter = col_counter + 1 ax2.set_title(None) - ax3 = fig.add_subplot(gs[:2, 2]) - UnitWaveformDensityMapWidget( - sorting_analyzer, - unit_ids=[unit_id], - unit_colors=unit_colors, - use_max_channel=True, - same_axis=False, - backend="matplotlib", - ax=ax3, - **unitwaveformdensitymapwidget_kwargs, - ) - ax3.set_ylabel(None) + if sorting_analyzer.has_extension("waveforms"): + ax3 = fig.add_subplot(gs[:2, col_counter]) + UnitWaveformDensityMapWidget( + sorting_analyzer, + unit_ids=[unit_id], + unit_colors=unit_colors, + use_max_channel=True, + same_axis=False, + backend="matplotlib", + ax=ax3, + **unitwaveformdensitymapwidget_kwargs, + ) + ax3.set_ylabel(None) + col_counter = col_counter + 1 if sorting_analyzer.has_extension("correlograms"): - ax4 = fig.add_subplot(gs[:2, 3]) + ax4 = fig.add_subplot(gs[:2, col_counter]) AutoCorrelogramsWidget( sorting_analyzer, unit_ids=[unit_id], @@ -180,8 +188,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax4.set_yticks([]) if sorting_analyzer.has_extension("spike_amplitudes"): - ax5 = fig.add_subplot(gs[2, :3]) - ax6 = fig.add_subplot(gs[2, 3]) + ax5 = fig.add_subplot(gs[2, :col_counter]) + ax6 = fig.add_subplot(gs[2, col_counter]) axes = np.array([ax5, ax6]) AmplitudesWidget( sorting_analyzer, From ebc6164064c72490cc1b9f539b1b2eb5f0eae9ea Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 21 Oct 2024 09:12:37 -0600 Subject: [PATCH 07/13] Update pyproject.toml Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 403988c980..fc09ad9198 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.9,<3.13" # Only numpy 2.0 supported on python 3.13 for windows. We need to wait for fix on neo +requires-python = ">=3.9,<3.13" # Only numpy 2.1 supported on python 3.13 for windows. We need to wait for fix on neo classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", From b5bd2fb1f4459689e05893a0524a8dab3543b5e8 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 23 Oct 2024 09:21:01 -0400 Subject: [PATCH 08/13] add error messaging around use of get data in templates --- .../core/analyzer_extension_core.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index bc5de63d07..e0b267ae72 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -380,7 +380,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N assert isinstance(operators, list) for operator in operators: if isinstance(operator, str): - assert operator in ("average", "std", "median", "mad") + if operator not in ("average", "std", "median", "mad"): + error_msg = ( + f"You have entered an operator {operator} in your `operators` argument which is " + f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead." + ) + raise ValueError(error_msg) else: assert isinstance(operator, (list, tuple)) assert len(operator) == 2 @@ -549,9 +554,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): if operator != "percentile": key = operator else: - assert percentile is not None, "You must provide percentile=..." + assert percentile is not None, "You must provide percentile=... if `operator=percentile`" key = f"percentile_{percentile}" + if key not in self.data.keys(): + error_msg = ( + f"You have entered `operator={key}`, but the only operators calculated are " + f"{list(self.data.keys())}. Please use one of these as your `operator` in the " + f"`get_data` function." + ) + raise ValueError(error_msg) + templates_array = self.data[key] if outputs == "numpy": From 22d19d52217594a129fc474bde119a2370e8d0e4 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 23 Oct 2024 09:36:17 -0400 Subject: [PATCH 09/13] more docs stuff --- .../core/analyzer_extension_core.py | 62 ++++++++++--------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index e0b267ae72..7d644c9c00 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -22,21 +22,23 @@ class ComputeRandomSpikes(AnalyzerExtension): """ - AnalyzerExtension that select some random spikes. + AnalyzerExtension that select somes random spikes. + This is allows for a subsampling of spikes for further calculations and is important + for managing that amount of memory and speed of computation in the analyzer. This will be used by the `waveforms`/`templates` extensions. - This internally use `random_spikes_selection()` parameters are the same. + This internally use `random_spikes_selection()` parameters. Parameters ---------- - method: "uniform" | "all", default: "uniform" + method : "uniform" | "all", default: "uniform" The method to select the spikes - max_spikes_per_unit: int, default: 500 + max_spikes_per_unit : int, default: 500 The maximum number of spikes per unit, ignored if method="all" - margin_size: int, default: None + margin_size : int, default: None A margin on each border of segments to avoid border spikes, ignored if method="all" - seed: int or None, default: None + seed : int or None, default: None A seed for the random generator, ignored if method="all" Returns @@ -104,7 +106,7 @@ def get_random_spikes(self): return self._some_spikes def get_selected_indices_in_spike_train(self, unit_id, segment_index): - # usefull for Waveforms extractor backwars compatibility + # useful for Waveforms extractor backwars compatibility # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain sorting = self.sorting_analyzer.sorting random_spikes_indices = self.data["random_spikes_indices"] @@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension): Parameters ---------- - ms_before: float, default: 1.0 + ms_before : float, default: 1.0 The number of ms to extract before the spike events - ms_after: float, default: 2.0 + ms_after : float, default: 2.0 The number of ms to extract after the spike events - dtype: None | dtype, default: None + dtype : None | dtype, default: None The dtype of the waveforms. If None, the dtype of the recording is used. Returns ------- - waveforms: np.ndarray + waveforms : np.ndarray Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels) """ @@ -410,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs): self._compute_and_append_from_waveforms(self.params["operators"]) else: - for operator in self.params["operators"]: - if operator not in ("average", "std"): - raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension") + bad_operator_list = [ + operator for operator in self.params["operators"] if operator not in ("average", "std") + ] + if len(bad_operator_list) > 0: + raise ValueError( + f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension" + ) recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting @@ -446,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs): def _compute_and_append_from_waveforms(self, operators): if not self.sorting_analyzer.has_extension("waveforms"): - raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension") + raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension") unit_ids = self.sorting_analyzer.unit_ids channel_ids = self.sorting_analyzer.channel_ids @@ -471,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators): assert self.sorting_analyzer.has_extension( "random_spikes" - ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()" + ), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')" some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() for unit_index, unit_id in enumerate(unit_ids): spike_mask = some_spikes["unit_index"] == unit_index @@ -579,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): probe=self.sorting_analyzer.get_probe(), ) else: - raise ValueError("outputs must be numpy or Templates") + raise ValueError("outputs must be `numpy` or `Templates`") def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"): """ @@ -589,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save Parameters ---------- - unit_ids: list or None + unit_ids : list or None Unit ids to retrieve waveforms for - operator: "average" | "median" | "std" | "percentile", default: "average" + operator : "average" | "median" | "std" | "percentile", default: "average" The operator to compute the templates - percentile: float, default: None + percentile : float, default: None Percentile to use for operator="percentile" - save: bool, default True + save : bool, default: True In case, the operator is not computed yet it can be saved to folder or zarr - outputs: "numpy" | "Templates" + outputs : "numpy" | "Templates", default: "numpy" Whether to return a numpy array or a Templates object Returns ------- - templates: np.array + templates : np.array | Templates The returned templates (num_units, num_samples, num_channels) """ if operator != "percentile": key = operator else: - assert percentile is not None, "You must provide percentile=..." + assert percentile is not None, "You must provide percentile=... if `operator='percentile'`" key = f"pencentile_{percentile}" if key in self.data: @@ -645,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save is_scaled=self.sorting_analyzer.return_scaled, ) else: - raise ValueError("outputs must be numpy or Templates") + raise ValueError("`outputs` must be 'numpy' or 'Templates'") def get_unit_template(self, unit_id, operator="average"): """ @@ -655,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"): ---------- unit_id: str | int Unit id to retrieve waveforms for - operator: str + operator: str, default: "average" The operator to compute the templates Returns @@ -713,13 +719,13 @@ def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None): return params def _select_extension_data(self, unit_ids): - # this do not depend on units + # this does not depend on units return self.data def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): - # this do not depend on units + # this does not depend on units return self.data.copy() def _run(self, verbose=False): From b1f11fbdcf3a9e94e77030f7da83368ed9ccea73 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 23 Oct 2024 09:38:03 -0400 Subject: [PATCH 10/13] fix typo --- src/spikeinterface/core/analyzer_extension_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 7d644c9c00..2d9924554c 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -23,7 +23,7 @@ class ComputeRandomSpikes(AnalyzerExtension): """ AnalyzerExtension that select somes random spikes. - This is allows for a subsampling of spikes for further calculations and is important + This allows for a subsampling of spikes for further calculations and is important for managing that amount of memory and speed of computation in the analyzer. This will be used by the `waveforms`/`templates` extensions. From 3406f85ba35209dd557ca9c0b0c15c5c84219e7a Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 25 Oct 2024 06:39:35 -0400 Subject: [PATCH 11/13] Joe's comments Co-authored-by: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> --- src/spikeinterface/core/analyzer_extension_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 2d9924554c..55e0e34dcc 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -28,7 +28,7 @@ class ComputeRandomSpikes(AnalyzerExtension): This will be used by the `waveforms`/`templates` extensions. - This internally use `random_spikes_selection()` parameters. + This internally uses `random_spikes_selection()` parameters. Parameters ---------- @@ -106,7 +106,7 @@ def get_random_spikes(self): return self._some_spikes def get_selected_indices_in_spike_train(self, unit_id, segment_index): - # useful for Waveforms extractor backwars compatibility + # useful for WaveformExtractor backwards compatibility # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain sorting = self.sorting_analyzer.sorting random_spikes_indices = self.data["random_spikes_indices"] From 7c953c2b1347dfbbb4058f5a9b7462f90d22c1dd Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 29 Oct 2024 12:19:26 +0000 Subject: [PATCH 12/13] Respond to Joe; add template_similarity check --- src/spikeinterface/exporters/report.py | 8 +-- src/spikeinterface/widgets/unit_summary.py | 82 +++++++++++----------- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 3a4be9213a..484da83342 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -20,10 +20,10 @@ def export_report( **job_kwargs, ): """ - Exports a SI spike sorting report. The report includes summary figures of the spike sorting output - (e.g. amplitude distributions, unit localization and depth VS amplitude) as well as unit-specific reports, - that include waveforms, templates, template maps, ISI distributions, and more. - + Exports a SI spike sorting report. The report includes summary figures of the spike sorting output. + What is plotted depends on what has been calculated. Unit locations and unit waveforms are always included. + Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms` + and 'template_similarity', and `spike_amplitudes` have been computed for the given `sorting_analyzer`. Parameters ---------- diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 8aea6fd690..d8cbeb7bb3 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -108,7 +108,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): fig = self.figure nrows = 2 ncols = 2 - if sorting_analyzer.has_extension("correlograms") or sorting_analyzer.has_extension("spike_amplitudes"): + if sorting_analyzer.has_extension("correlograms") and sorting_analyzer.has_extension("template_similarity"): ncols += 1 if sorting_analyzer.has_extension("waveforms"): ncols += 1 @@ -117,31 +117,30 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): gs = fig.add_gridspec(nrows, ncols) col_counter = 0 - if sorting_analyzer.has_extension("unit_locations"): - ax1 = fig.add_subplot(gs[:2, col_counter]) - # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) - w = UnitLocationsWidget( - sorting_analyzer, - unit_ids=[unit_id], - unit_colors=unit_colors, - plot_legend=False, - backend="matplotlib", - ax=ax1, - **unitlocationswidget_kwargs, - ) - col_counter = col_counter + 1 - - unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - unit_location = unit_locations[unit_id] - x, y = unit_location[0], unit_location[1] - ax1.set_xlim(x - 80, x + 80) - ax1.set_ylim(y - 250, y + 250) - ax1.set_xticks([]) - ax1.set_xlabel(None) - ax1.set_ylabel(None) - - ax2 = fig.add_subplot(gs[:2, col_counter]) - w = UnitWaveformsWidget( + # Unit locations and unit waveform plots are always generated + ax_unit_locations = fig.add_subplot(gs[:2, col_counter]) + _ = UnitLocationsWidget( + sorting_analyzer, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_legend=False, + backend="matplotlib", + ax=ax_unit_locations, + **unitlocationswidget_kwargs, + ) + col_counter += 1 + + unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") + unit_location = unit_locations[unit_id] + x, y = unit_location[0], unit_location[1] + ax_unit_locations.set_xlim(x - 80, x + 80) + ax_unit_locations.set_ylim(y - 250, y + 250) + ax_unit_locations.set_xticks([]) + ax_unit_locations.set_xlabel(None) + ax_unit_locations.set_ylabel(None) + + ax_unit_waveforms = fig.add_subplot(gs[:2, col_counter]) + _ = UnitWaveformsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, @@ -151,15 +150,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, sparsity=sparsity, backend="matplotlib", - ax=ax2, + ax=ax_unit_waveforms, **unitwaveformswidget_kwargs, ) - col_counter = col_counter + 1 + col_counter += 1 - ax2.set_title(None) + ax_unit_waveforms.set_title(None) if sorting_analyzer.has_extension("waveforms"): - ax3 = fig.add_subplot(gs[:2, col_counter]) + ax_waveform_density = fig.add_subplot(gs[:2, col_counter]) UnitWaveformDensityMapWidget( sorting_analyzer, unit_ids=[unit_id], @@ -167,30 +166,31 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): use_max_channel=True, same_axis=False, backend="matplotlib", - ax=ax3, + ax=ax_waveform_density, **unitwaveformdensitymapwidget_kwargs, ) - ax3.set_ylabel(None) - col_counter = col_counter + 1 + col_counter += 1 + ax_waveform_density.set_ylabel(None) - if sorting_analyzer.has_extension("correlograms"): - ax4 = fig.add_subplot(gs[:2, col_counter]) + if sorting_analyzer.has_extension("correlograms") and sorting_analyzer.has_extension("template_similarity"): + ax_correlograms = fig.add_subplot(gs[:2, col_counter]) AutoCorrelogramsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, backend="matplotlib", - ax=ax4, + ax=ax_correlograms, **autocorrelogramswidget_kwargs, ) + col_counter += 1 - ax4.set_title(None) - ax4.set_yticks([]) + ax_correlograms.set_title(None) + ax_correlograms.set_yticks([]) if sorting_analyzer.has_extension("spike_amplitudes"): - ax5 = fig.add_subplot(gs[2, :col_counter]) - ax6 = fig.add_subplot(gs[2, col_counter]) - axes = np.array([ax5, ax6]) + ax_spike_amps = fig.add_subplot(gs[2, : col_counter - 1]) + ax_amps_distribution = fig.add_subplot(gs[2, col_counter - 1]) + axes = np.array([ax_spike_amps, ax_amps_distribution]) AmplitudesWidget( sorting_analyzer, unit_ids=[unit_id], From 2faed131014fe46fd0fafddcb9b94872f889ca7c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Nov 2024 16:05:42 +0100 Subject: [PATCH 13/13] Remove the need of template_similarity extension for autocorrelogram plot --- src/spikeinterface/exporters/report.py | 4 ++-- src/spikeinterface/widgets/autocorrelograms.py | 8 +++++++- src/spikeinterface/widgets/crosscorrelograms.py | 3 ++- src/spikeinterface/widgets/unit_summary.py | 4 ++-- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 484da83342..ab08401382 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -22,8 +22,8 @@ def export_report( """ Exports a SI spike sorting report. The report includes summary figures of the spike sorting output. What is plotted depends on what has been calculated. Unit locations and unit waveforms are always included. - Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms` - and 'template_similarity', and `spike_amplitudes` have been computed for the given `sorting_analyzer`. + Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms`, + and `spike_amplitudes` have been computed for the given `sorting_analyzer`. Parameters ---------- diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index c8acd93dc2..c211a277f8 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -9,7 +9,13 @@ class AutoCorrelogramsWidget(CrossCorrelogramsWidget): # the doc is copied form CrossCorrelogramsWidget def __init__(self, *args, **kargs): - CrossCorrelogramsWidget.__init__(self, *args, **kargs) + _ = kargs.pop("min_similarity_for_correlograms", 0.0) + CrossCorrelogramsWidget.__init__( + self, + *args, + **kargs, + min_similarity_for_correlograms=None, + ) def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index cdb2041aa3..88dd803323 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -21,7 +21,8 @@ class CrossCorrelogramsWidget(BaseWidget): List of unit ids min_similarity_for_correlograms : float, default: 0.2 For sortingview backend. Threshold for computing pair-wise cross-correlograms. - If template similarity between two units is below this threshold, the cross-correlogram is not displayed + If template similarity between two units is below this threshold, the cross-correlogram is not displayed. + For auto-correlograms plot, this is automatically set to None. window_ms : float, default: 100.0 Window for CCGs in ms. If correlograms are already computed (e.g. with SortingAnalyzer), this argument is ignored diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index d8cbeb7bb3..9466110110 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -108,7 +108,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): fig = self.figure nrows = 2 ncols = 2 - if sorting_analyzer.has_extension("correlograms") and sorting_analyzer.has_extension("template_similarity"): + if sorting_analyzer.has_extension("correlograms"): ncols += 1 if sorting_analyzer.has_extension("waveforms"): ncols += 1 @@ -172,7 +172,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): col_counter += 1 ax_waveform_density.set_ylabel(None) - if sorting_analyzer.has_extension("correlograms") and sorting_analyzer.has_extension("template_similarity"): + if sorting_analyzer.has_extension("correlograms"): ax_correlograms = fig.add_subplot(gs[:2, col_counter]) AutoCorrelogramsWidget( sorting_analyzer,