From c1504f641fef0b7888c6c7d6afe63b7b4ab402c0 Mon Sep 17 00:00:00 2001 From: Jake Swann Date: Thu, 3 Oct 2024 13:39:10 -0400 Subject: [PATCH 01/42] Added vspacing_factor as a param for TracesWidget --- src/spikeinterface/widgets/traces.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 86f2350a85..f5dadc780f 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -52,6 +52,8 @@ class TracesWidget(BaseWidget): If dict, keys should be the same as recording keys scale : float, default: 1 Scale factor for the traces + vspacing_factor : float, default: 1.5 + Vertical spacing between channels as a multiple of maximum channel amplitude with_colorbar : bool, default: True When mode is "map", a colorbar is added tile_size : int, default: 1500 @@ -82,6 +84,7 @@ def __init__( tile_size=1500, seconds_per_row=0.2, scale=1, + vspacing_factor=1.5, with_colorbar=True, add_legend=True, backend=None, @@ -168,7 +171,7 @@ def __init__( traces0 = list_traces[0] mean_channel_std = np.mean(np.std(traces0, axis=0)) max_channel_amp = np.max(np.max(np.abs(traces0), axis=0)) - vspacing = max_channel_amp * 1.5 + vspacing = max_channel_amp * vspacing_factor if rec0.get_channel_groups() is None: color_groups = False From e2fb8cc6985ab51c4d599122ab9643071bd950b9 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 15 Oct 2024 14:31:25 -0600 Subject: [PATCH 02/42] cap python --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a43ab63c8e..e535747428 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.9,<4.0" +requires-python = ">=3.9,<3.14" # Only numpy 2.0 supported on python 3.12 for windows. We need to wait for fix on neo classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", From 5347cba25a67566f0adc46b3dee8e24bfd1a545a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 15 Oct 2024 14:55:52 -0600 Subject: [PATCH 03/42] Update pyproject.toml Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e535747428..403988c980 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.9,<3.14" # Only numpy 2.0 supported on python 3.12 for windows. We need to wait for fix on neo +requires-python = ">=3.9,<3.13" # Only numpy 2.0 supported on python 3.13 for windows. We need to wait for fix on neo classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", From 146d34a58f29165307bd5e00af7e6c4fbb8d2306 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 15 Oct 2024 17:16:02 -0400 Subject: [PATCH 04/42] remove writing text_file --- src/spikeinterface/sorters/basesorter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 3502d27548..28948f81cc 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -145,9 +145,8 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo elif recording.check_serializability("pickle"): recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder) else: - # TODO: deprecate and finally remove this after 0.100 - d = {"warning": "The recording is not serializable to json"} - rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") + raise RuntimeError("This recording is not serializable and so can not be sorted. Consider `recording.save()` to save a " + "compatible binary file.") return output_folder From 5e832c9523a218a032c75d774c71c9188a8114ae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 21:19:30 +0000 Subject: [PATCH 05/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/basesorter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 28948f81cc..c59fa29c05 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -145,8 +145,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo elif recording.check_serializability("pickle"): recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder) else: - raise RuntimeError("This recording is not serializable and so can not be sorted. Consider `recording.save()` to save a " - "compatible binary file.") + raise RuntimeError( + "This recording is not serializable and so can not be sorted. Consider `recording.save()` to save a " + "compatible binary file." + ) return output_folder From b62e02ba244c40f9eb86e2206bc844e2be339a2d Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 21 Oct 2024 10:39:37 +0100 Subject: [PATCH 06/42] export_report without waveforms --- src/spikeinterface/widgets/unit_summary.py | 44 +++++++++++++--------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 755e60ccbf..8aea6fd690 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -107,15 +107,18 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # and use custum grid spec fig = self.figure nrows = 2 - ncols = 3 + ncols = 2 if sorting_analyzer.has_extension("correlograms") or sorting_analyzer.has_extension("spike_amplitudes"): ncols += 1 + if sorting_analyzer.has_extension("waveforms"): + ncols += 1 if sorting_analyzer.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) + col_counter = 0 if sorting_analyzer.has_extension("unit_locations"): - ax1 = fig.add_subplot(gs[:2, 0]) + ax1 = fig.add_subplot(gs[:2, col_counter]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( sorting_analyzer, @@ -126,6 +129,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax=ax1, **unitlocationswidget_kwargs, ) + col_counter = col_counter + 1 unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") unit_location = unit_locations[unit_id] @@ -136,12 +140,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1.set_xlabel(None) ax1.set_ylabel(None) - ax2 = fig.add_subplot(gs[:2, 1]) + ax2 = fig.add_subplot(gs[:2, col_counter]) w = UnitWaveformsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, plot_templates=True, + plot_waveforms=sorting_analyzer.has_extension("waveforms"), same_axis=True, plot_legend=False, sparsity=sparsity, @@ -149,24 +154,27 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax=ax2, **unitwaveformswidget_kwargs, ) + col_counter = col_counter + 1 ax2.set_title(None) - ax3 = fig.add_subplot(gs[:2, 2]) - UnitWaveformDensityMapWidget( - sorting_analyzer, - unit_ids=[unit_id], - unit_colors=unit_colors, - use_max_channel=True, - same_axis=False, - backend="matplotlib", - ax=ax3, - **unitwaveformdensitymapwidget_kwargs, - ) - ax3.set_ylabel(None) + if sorting_analyzer.has_extension("waveforms"): + ax3 = fig.add_subplot(gs[:2, col_counter]) + UnitWaveformDensityMapWidget( + sorting_analyzer, + unit_ids=[unit_id], + unit_colors=unit_colors, + use_max_channel=True, + same_axis=False, + backend="matplotlib", + ax=ax3, + **unitwaveformdensitymapwidget_kwargs, + ) + ax3.set_ylabel(None) + col_counter = col_counter + 1 if sorting_analyzer.has_extension("correlograms"): - ax4 = fig.add_subplot(gs[:2, 3]) + ax4 = fig.add_subplot(gs[:2, col_counter]) AutoCorrelogramsWidget( sorting_analyzer, unit_ids=[unit_id], @@ -180,8 +188,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax4.set_yticks([]) if sorting_analyzer.has_extension("spike_amplitudes"): - ax5 = fig.add_subplot(gs[2, :3]) - ax6 = fig.add_subplot(gs[2, 3]) + ax5 = fig.add_subplot(gs[2, :col_counter]) + ax6 = fig.add_subplot(gs[2, col_counter]) axes = np.array([ax5, ax6]) AmplitudesWidget( sorting_analyzer, From ebc6164064c72490cc1b9f539b1b2eb5f0eae9ea Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 21 Oct 2024 09:12:37 -0600 Subject: [PATCH 07/42] Update pyproject.toml Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 403988c980..fc09ad9198 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.9,<3.13" # Only numpy 2.0 supported on python 3.13 for windows. We need to wait for fix on neo +requires-python = ">=3.9,<3.13" # Only numpy 2.1 supported on python 3.13 for windows. We need to wait for fix on neo classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", From b5bd2fb1f4459689e05893a0524a8dab3543b5e8 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 23 Oct 2024 09:21:01 -0400 Subject: [PATCH 08/42] add error messaging around use of get data in templates --- .../core/analyzer_extension_core.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index bc5de63d07..e0b267ae72 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -380,7 +380,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N assert isinstance(operators, list) for operator in operators: if isinstance(operator, str): - assert operator in ("average", "std", "median", "mad") + if operator not in ("average", "std", "median", "mad"): + error_msg = ( + f"You have entered an operator {operator} in your `operators` argument which is " + f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead." + ) + raise ValueError(error_msg) else: assert isinstance(operator, (list, tuple)) assert len(operator) == 2 @@ -549,9 +554,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): if operator != "percentile": key = operator else: - assert percentile is not None, "You must provide percentile=..." + assert percentile is not None, "You must provide percentile=... if `operator=percentile`" key = f"percentile_{percentile}" + if key not in self.data.keys(): + error_msg = ( + f"You have entered `operator={key}`, but the only operators calculated are " + f"{list(self.data.keys())}. Please use one of these as your `operator` in the " + f"`get_data` function." + ) + raise ValueError(error_msg) + templates_array = self.data[key] if outputs == "numpy": From 22d19d52217594a129fc474bde119a2370e8d0e4 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 23 Oct 2024 09:36:17 -0400 Subject: [PATCH 09/42] more docs stuff --- .../core/analyzer_extension_core.py | 62 ++++++++++--------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index e0b267ae72..7d644c9c00 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -22,21 +22,23 @@ class ComputeRandomSpikes(AnalyzerExtension): """ - AnalyzerExtension that select some random spikes. + AnalyzerExtension that select somes random spikes. + This is allows for a subsampling of spikes for further calculations and is important + for managing that amount of memory and speed of computation in the analyzer. This will be used by the `waveforms`/`templates` extensions. - This internally use `random_spikes_selection()` parameters are the same. + This internally use `random_spikes_selection()` parameters. Parameters ---------- - method: "uniform" | "all", default: "uniform" + method : "uniform" | "all", default: "uniform" The method to select the spikes - max_spikes_per_unit: int, default: 500 + max_spikes_per_unit : int, default: 500 The maximum number of spikes per unit, ignored if method="all" - margin_size: int, default: None + margin_size : int, default: None A margin on each border of segments to avoid border spikes, ignored if method="all" - seed: int or None, default: None + seed : int or None, default: None A seed for the random generator, ignored if method="all" Returns @@ -104,7 +106,7 @@ def get_random_spikes(self): return self._some_spikes def get_selected_indices_in_spike_train(self, unit_id, segment_index): - # usefull for Waveforms extractor backwars compatibility + # useful for Waveforms extractor backwars compatibility # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain sorting = self.sorting_analyzer.sorting random_spikes_indices = self.data["random_spikes_indices"] @@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension): Parameters ---------- - ms_before: float, default: 1.0 + ms_before : float, default: 1.0 The number of ms to extract before the spike events - ms_after: float, default: 2.0 + ms_after : float, default: 2.0 The number of ms to extract after the spike events - dtype: None | dtype, default: None + dtype : None | dtype, default: None The dtype of the waveforms. If None, the dtype of the recording is used. Returns ------- - waveforms: np.ndarray + waveforms : np.ndarray Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels) """ @@ -410,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs): self._compute_and_append_from_waveforms(self.params["operators"]) else: - for operator in self.params["operators"]: - if operator not in ("average", "std"): - raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension") + bad_operator_list = [ + operator for operator in self.params["operators"] if operator not in ("average", "std") + ] + if len(bad_operator_list) > 0: + raise ValueError( + f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension" + ) recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting @@ -446,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs): def _compute_and_append_from_waveforms(self, operators): if not self.sorting_analyzer.has_extension("waveforms"): - raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension") + raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension") unit_ids = self.sorting_analyzer.unit_ids channel_ids = self.sorting_analyzer.channel_ids @@ -471,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators): assert self.sorting_analyzer.has_extension( "random_spikes" - ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()" + ), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')" some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() for unit_index, unit_id in enumerate(unit_ids): spike_mask = some_spikes["unit_index"] == unit_index @@ -579,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): probe=self.sorting_analyzer.get_probe(), ) else: - raise ValueError("outputs must be numpy or Templates") + raise ValueError("outputs must be `numpy` or `Templates`") def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"): """ @@ -589,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save Parameters ---------- - unit_ids: list or None + unit_ids : list or None Unit ids to retrieve waveforms for - operator: "average" | "median" | "std" | "percentile", default: "average" + operator : "average" | "median" | "std" | "percentile", default: "average" The operator to compute the templates - percentile: float, default: None + percentile : float, default: None Percentile to use for operator="percentile" - save: bool, default True + save : bool, default: True In case, the operator is not computed yet it can be saved to folder or zarr - outputs: "numpy" | "Templates" + outputs : "numpy" | "Templates", default: "numpy" Whether to return a numpy array or a Templates object Returns ------- - templates: np.array + templates : np.array | Templates The returned templates (num_units, num_samples, num_channels) """ if operator != "percentile": key = operator else: - assert percentile is not None, "You must provide percentile=..." + assert percentile is not None, "You must provide percentile=... if `operator='percentile'`" key = f"pencentile_{percentile}" if key in self.data: @@ -645,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save is_scaled=self.sorting_analyzer.return_scaled, ) else: - raise ValueError("outputs must be numpy or Templates") + raise ValueError("`outputs` must be 'numpy' or 'Templates'") def get_unit_template(self, unit_id, operator="average"): """ @@ -655,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"): ---------- unit_id: str | int Unit id to retrieve waveforms for - operator: str + operator: str, default: "average" The operator to compute the templates Returns @@ -713,13 +719,13 @@ def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None): return params def _select_extension_data(self, unit_ids): - # this do not depend on units + # this does not depend on units return self.data def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): - # this do not depend on units + # this does not depend on units return self.data.copy() def _run(self, verbose=False): From b1f11fbdcf3a9e94e77030f7da83368ed9ccea73 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 23 Oct 2024 09:38:03 -0400 Subject: [PATCH 10/42] fix typo --- src/spikeinterface/core/analyzer_extension_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 7d644c9c00..2d9924554c 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -23,7 +23,7 @@ class ComputeRandomSpikes(AnalyzerExtension): """ AnalyzerExtension that select somes random spikes. - This is allows for a subsampling of spikes for further calculations and is important + This allows for a subsampling of spikes for further calculations and is important for managing that amount of memory and speed of computation in the analyzer. This will be used by the `waveforms`/`templates` extensions. From f66ae7fc9c6cf66c5ca35d11dddedfbb2180080d Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 25 Oct 2024 08:28:41 +0100 Subject: [PATCH 11/42] Compute covariance matrix in float64. --- src/spikeinterface/preprocessing/whiten.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 195969ff79..91c74c423f 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -124,7 +124,7 @@ def __init__(self, parent_recording_segment, W, M, dtype, int_scale): def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) traces_dtype = traces.dtype - # if uint --> force int + # if uint --> force float if traces_dtype.kind == "u": traces = traces.astype("float32") @@ -185,6 +185,7 @@ def compute_whitening_matrix( """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) + random_data = random_data.astype(np.float64) regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} From b5c260aed812d6fb6202ffdf13d35e28d79ff4e9 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 25 Oct 2024 09:16:26 +0100 Subject: [PATCH 12/42] Update docstring. --- src/spikeinterface/preprocessing/whiten.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 91c74c423f..1c81f2ae42 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -19,6 +19,8 @@ class WhitenRecording(BasePreprocessor): recording : RecordingExtractor The recording extractor to be whitened. dtype : None or dtype, default: None + Datatype of the output recording (covariance matrix estimation + and whitening are performed in float64. If None the the parent dtype is kept. For integer dtype a int_scale must be also given. mode : "global" | "local", default: "global" @@ -74,7 +76,8 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, "For recording with dtype=int you must set dtype=float32 OR set a int_scale" + assert int_scale is not None, ("For recording with dtype=int you must set the output dtype to float " + " OR set a int_scale") if W is not None: W = np.asarray(W) From 18cfb2b385d9cf5e18d622097a41631d94a0e9a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 25 Oct 2024 08:18:56 +0000 Subject: [PATCH 13/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/whiten.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 1c81f2ae42..4e3135c3e9 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -76,8 +76,9 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, ("For recording with dtype=int you must set the output dtype to float " - " OR set a int_scale") + assert int_scale is not None, ( + "For recording with dtype=int you must set the output dtype to float " " OR set a int_scale" + ) if W is not None: W = np.asarray(W) From 98e5db95aa36a415d520cfe758113fc7c5db9bac Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 25 Oct 2024 10:42:58 +0200 Subject: [PATCH 14/42] recording_slices in run_node_pipeline() --- src/spikeinterface/core/job_tools.py | 22 +++++++++--------- src/spikeinterface/core/node_pipeline.py | 7 +++++- src/spikeinterface/core/recording_tools.py | 2 +- .../core/tests/test_node_pipeline.py | 23 +++++++++++++++---- .../sortingcomponents/peak_detection.py | 6 +++++ 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 27f05bb36b..7a6172369b 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size): def divide_recording_into_chunks(recording, chunk_size): - all_chunks = [] + recording_slices = [] for segment_index in range(recording.get_num_segments()): num_frames = recording.get_num_samples(segment_index) chunks = divide_segment_into_chunks(num_frames, chunk_size) - all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) - return all_chunks + recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + return recording_slices def ensure_n_jobs(recording, n_jobs=1): @@ -387,13 +387,13 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self, all_chunks=None): + def run(self, recording_slices=None): """ Runs the defined jobs. """ - if all_chunks is None: - all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) + if recording_slices is None: + recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size) if self.handle_returns: returns = [] @@ -402,17 +402,17 @@ def run(self, all_chunks=None): if self.n_jobs == 1: if self.progress_bar: - all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name) + recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name) worker_ctx = self.init_func(*self.init_args) - for segment_index, frame_start, frame_stop in all_chunks: + for segment_index, frame_start, frame_stop in recording_slices: res = self.func(segment_index, frame_start, frame_stop, worker_ctx) if self.handle_returns: returns.append(res) if self.gather_func is not None: self.gather_func(res) else: - n_jobs = min(self.n_jobs, len(all_chunks)) + n_jobs = min(self.n_jobs, len(recording_slices)) # parallel with ProcessPoolExecutor( @@ -421,10 +421,10 @@ def run(self, all_chunks=None): mp_context=mp.get_context(self.mp_context), initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), ) as executor: - results = executor.map(function_wrapper, all_chunks) + results = executor.map(function_wrapper, recording_slices) if self.progress_bar: - results = tqdm(results, desc=self.job_name, total=len(all_chunks)) + results = tqdm(results, desc=self.job_name, total=len(recording_slices)) for res in results: if self.handle_returns: diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index d90a20902d..8ca4ba7f3a 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -489,6 +489,7 @@ def run_node_pipeline( names=None, verbose=False, skip_after_n_peaks=None, + recording_slices=None, ): """ Machinery to compute in parallel operations on peaks and traces. @@ -540,6 +541,10 @@ def run_node_pipeline( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. + recording_slices : None | list[tuple] + Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + It must be a list of (segment_index, frame_start, frame_stop). + If None (default), the entire recording is computed. Returns ------- @@ -578,7 +583,7 @@ def run_node_pipeline( **job_kwargs, ) - processor.run() + processor.run(recording_slices=recording_slices) outs = gather_func.finalize_buffers(squeeze_output=squeeze_output) return outs diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 2ab74ce51e..4aabbfd587 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -806,7 +806,7 @@ def append_noise_chunk(res): gather_func=append_noise_chunk, **job_kwargs, ) - executor.run(all_chunks=recording_slices) + executor.run(recording_slices=recording_slices) noise_levels_chunks = np.stack(noise_levels_chunks) noise_levels = np.mean(noise_levels_chunks, axis=0) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index deef2291c6..400a71c424 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -4,7 +4,7 @@ import shutil from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording - +from spikeinterface.core.job_tools import divide_recording_into_chunks # from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ( @@ -191,8 +191,8 @@ def test_run_node_pipeline(cache_folder_creation): unpickled_node = pickle.loads(pickled_node) -def test_skip_after_n_peaks(): - recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) +def test_skip_after_n_peaks_and_recording_slices(): + recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0], seed=2205) # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) @@ -211,18 +211,31 @@ def test_skip_after_n_peaks(): node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True) nodes = [node0, node1] + # skip skip_after_n_peaks = 30 some_amplitudes = run_node_pipeline( recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks ) - assert some_amplitudes.size >= skip_after_n_peaks assert some_amplitudes.size < spikes.size + # slices : 1 every 4 + recording_slices = divide_recording_into_chunks(recording, 10_000) + recording_slices = recording_slices[::4] + some_amplitudes = run_node_pipeline( + recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices + ) + tolerance = 1.2 + assert some_amplitudes.size < (spikes.size // 4) * tolerance + + + + + # the following is for testing locally with python or ipython. It is not used in ci or with pytest. if __name__ == "__main__": # folder = Path("./cache_folder/core") # test_run_node_pipeline(folder) - test_skip_after_n_peaks() + test_skip_after_n_peaks_and_recording_slices() diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 5b1d33b334..233b16dcf7 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -57,6 +57,7 @@ def detect_peaks( folder=None, names=None, skip_after_n_peaks=None, + recording_slices=None, **kwargs, ): """Peak detection based on threshold crossing in term of k x MAD. @@ -83,6 +84,10 @@ def detect_peaks( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. + recording_slices : None | list[tuple] + Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + It must be a list of (segment_index, frame_start, frame_stop). + If None (default), the entire recording is computed. {method_doc} {job_doc} @@ -135,6 +140,7 @@ def detect_peaks( folder=folder, names=names, skip_after_n_peaks=skip_after_n_peaks, + recording_slices=recording_slices, ) return outs From aaa689fa9174e8576550528224431b9ea3e32759 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 25 Oct 2024 08:47:02 +0000 Subject: [PATCH 15/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_node_pipeline.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 400a71c424..028eaecf12 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -229,10 +229,6 @@ def test_skip_after_n_peaks_and_recording_slices(): assert some_amplitudes.size < (spikes.size // 4) * tolerance - - - - # the following is for testing locally with python or ipython. It is not used in ci or with pytest. if __name__ == "__main__": # folder = Path("./cache_folder/core") From 3406f85ba35209dd557ca9c0b0c15c5c84219e7a Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 25 Oct 2024 06:39:35 -0400 Subject: [PATCH 16/42] Joe's comments Co-authored-by: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> --- src/spikeinterface/core/analyzer_extension_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 2d9924554c..55e0e34dcc 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -28,7 +28,7 @@ class ComputeRandomSpikes(AnalyzerExtension): This will be used by the `waveforms`/`templates` extensions. - This internally use `random_spikes_selection()` parameters. + This internally uses `random_spikes_selection()` parameters. Parameters ---------- @@ -106,7 +106,7 @@ def get_random_spikes(self): return self._some_spikes def get_selected_indices_in_spike_train(self, unit_id, segment_index): - # useful for Waveforms extractor backwars compatibility + # useful for WaveformExtractor backwards compatibility # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain sorting = self.sorting_analyzer.sorting random_spikes_indices = self.data["random_spikes_indices"] From f55a9405040310064f716015f8d9b0c976b97923 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 28 Oct 2024 14:28:05 +0000 Subject: [PATCH 17/42] Add 'shift start time' function. --- src/spikeinterface/core/baserecording.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 5e2e9e4014..b8a0420794 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -509,6 +509,26 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency + def shift_start_time(self, shift, segment_index=None): + """ + Shift the starting time of the times. + + shift : int | float + The shift to apply to the first time point. If positive, + the current start time will be increased by `shift`. If + negative, the start time will be decreased. + + segment_index : int | None + The segment on which to shift the times. + """ + segment_index = self._check_segment_index(segment_index) + rs = self._recording_segments[segment_index] + + if self.has_time_vector(): + rs.time_vector += shift + else: + rs.t_start += shift + def sample_index_to_time(self, sample_ind, segment_index=None): """ Transform sample index into time in seconds From d17181f3bd68f602780ad99e1b618aa3f793b8ad Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Oct 2024 13:08:56 +0100 Subject: [PATCH 18/42] Update src/spikeinterface/preprocessing/whiten.py --- src/spikeinterface/preprocessing/whiten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 4e3135c3e9..505e8a330a 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -20,7 +20,7 @@ class WhitenRecording(BasePreprocessor): The recording extractor to be whitened. dtype : None or dtype, default: None Datatype of the output recording (covariance matrix estimation - and whitening are performed in float64. + and whitening are performed in float64). If None the the parent dtype is kept. For integer dtype a int_scale must be also given. mode : "global" | "local", default: "global" From 7c953c2b1347dfbbb4058f5a9b7462f90d22c1dd Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 29 Oct 2024 12:19:26 +0000 Subject: [PATCH 19/42] Respond to Joe; add template_similarity check --- src/spikeinterface/exporters/report.py | 8 +-- src/spikeinterface/widgets/unit_summary.py | 82 +++++++++++----------- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 3a4be9213a..484da83342 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -20,10 +20,10 @@ def export_report( **job_kwargs, ): """ - Exports a SI spike sorting report. The report includes summary figures of the spike sorting output - (e.g. amplitude distributions, unit localization and depth VS amplitude) as well as unit-specific reports, - that include waveforms, templates, template maps, ISI distributions, and more. - + Exports a SI spike sorting report. The report includes summary figures of the spike sorting output. + What is plotted depends on what has been calculated. Unit locations and unit waveforms are always included. + Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms` + and 'template_similarity', and `spike_amplitudes` have been computed for the given `sorting_analyzer`. Parameters ---------- diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 8aea6fd690..d8cbeb7bb3 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -108,7 +108,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): fig = self.figure nrows = 2 ncols = 2 - if sorting_analyzer.has_extension("correlograms") or sorting_analyzer.has_extension("spike_amplitudes"): + if sorting_analyzer.has_extension("correlograms") and sorting_analyzer.has_extension("template_similarity"): ncols += 1 if sorting_analyzer.has_extension("waveforms"): ncols += 1 @@ -117,31 +117,30 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): gs = fig.add_gridspec(nrows, ncols) col_counter = 0 - if sorting_analyzer.has_extension("unit_locations"): - ax1 = fig.add_subplot(gs[:2, col_counter]) - # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) - w = UnitLocationsWidget( - sorting_analyzer, - unit_ids=[unit_id], - unit_colors=unit_colors, - plot_legend=False, - backend="matplotlib", - ax=ax1, - **unitlocationswidget_kwargs, - ) - col_counter = col_counter + 1 - - unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - unit_location = unit_locations[unit_id] - x, y = unit_location[0], unit_location[1] - ax1.set_xlim(x - 80, x + 80) - ax1.set_ylim(y - 250, y + 250) - ax1.set_xticks([]) - ax1.set_xlabel(None) - ax1.set_ylabel(None) - - ax2 = fig.add_subplot(gs[:2, col_counter]) - w = UnitWaveformsWidget( + # Unit locations and unit waveform plots are always generated + ax_unit_locations = fig.add_subplot(gs[:2, col_counter]) + _ = UnitLocationsWidget( + sorting_analyzer, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_legend=False, + backend="matplotlib", + ax=ax_unit_locations, + **unitlocationswidget_kwargs, + ) + col_counter += 1 + + unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") + unit_location = unit_locations[unit_id] + x, y = unit_location[0], unit_location[1] + ax_unit_locations.set_xlim(x - 80, x + 80) + ax_unit_locations.set_ylim(y - 250, y + 250) + ax_unit_locations.set_xticks([]) + ax_unit_locations.set_xlabel(None) + ax_unit_locations.set_ylabel(None) + + ax_unit_waveforms = fig.add_subplot(gs[:2, col_counter]) + _ = UnitWaveformsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, @@ -151,15 +150,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, sparsity=sparsity, backend="matplotlib", - ax=ax2, + ax=ax_unit_waveforms, **unitwaveformswidget_kwargs, ) - col_counter = col_counter + 1 + col_counter += 1 - ax2.set_title(None) + ax_unit_waveforms.set_title(None) if sorting_analyzer.has_extension("waveforms"): - ax3 = fig.add_subplot(gs[:2, col_counter]) + ax_waveform_density = fig.add_subplot(gs[:2, col_counter]) UnitWaveformDensityMapWidget( sorting_analyzer, unit_ids=[unit_id], @@ -167,30 +166,31 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): use_max_channel=True, same_axis=False, backend="matplotlib", - ax=ax3, + ax=ax_waveform_density, **unitwaveformdensitymapwidget_kwargs, ) - ax3.set_ylabel(None) - col_counter = col_counter + 1 + col_counter += 1 + ax_waveform_density.set_ylabel(None) - if sorting_analyzer.has_extension("correlograms"): - ax4 = fig.add_subplot(gs[:2, col_counter]) + if sorting_analyzer.has_extension("correlograms") and sorting_analyzer.has_extension("template_similarity"): + ax_correlograms = fig.add_subplot(gs[:2, col_counter]) AutoCorrelogramsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, backend="matplotlib", - ax=ax4, + ax=ax_correlograms, **autocorrelogramswidget_kwargs, ) + col_counter += 1 - ax4.set_title(None) - ax4.set_yticks([]) + ax_correlograms.set_title(None) + ax_correlograms.set_yticks([]) if sorting_analyzer.has_extension("spike_amplitudes"): - ax5 = fig.add_subplot(gs[2, :col_counter]) - ax6 = fig.add_subplot(gs[2, col_counter]) - axes = np.array([ax5, ax6]) + ax_spike_amps = fig.add_subplot(gs[2, : col_counter - 1]) + ax_amps_distribution = fig.add_subplot(gs[2, col_counter - 1]) + axes = np.array([ax_spike_amps, ax_amps_distribution]) AmplitudesWidget( sorting_analyzer, unit_ids=[unit_id], From 12538cc646f47b162b73190326d5b541121b2c1a Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 31 Oct 2024 17:49:29 +0000 Subject: [PATCH 20/42] Revert to float32. --- src/spikeinterface/preprocessing/whiten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 505e8a330a..b9c106a5a2 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -20,7 +20,7 @@ class WhitenRecording(BasePreprocessor): The recording extractor to be whitened. dtype : None or dtype, default: None Datatype of the output recording (covariance matrix estimation - and whitening are performed in float64). + and whitening are performed in float32). If None the the parent dtype is kept. For integer dtype a int_scale must be also given. mode : "global" | "local", default: "global" @@ -189,7 +189,7 @@ def compute_whitening_matrix( """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) - random_data = random_data.astype(np.float64) + random_data = random_data.astype(np.float32) regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} From 035d61c38dbdb453a6461de424028d1466367bda Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:02:03 +0000 Subject: [PATCH 21/42] Fix string format error. --- src/spikeinterface/preprocessing/whiten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index b9c106a5a2..fa33975a68 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -77,7 +77,7 @@ def __init__( if dtype_.kind == "i": assert int_scale is not None, ( - "For recording with dtype=int you must set the output dtype to float " " OR set a int_scale" + "For recording with dtype=int you must set the output dtype to float OR set a int_scale" ) if W is not None: From 7b8d0a2c1c3e006d8a9a46257e0f06e034aa0a76 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:02:31 +0000 Subject: [PATCH 22/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/whiten.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index fa33975a68..57400c1199 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -76,9 +76,9 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, ( - "For recording with dtype=int you must set the output dtype to float OR set a int_scale" - ) + assert ( + int_scale is not None + ), "For recording with dtype=int you must set the output dtype to float OR set a int_scale" if W is not None: W = np.asarray(W) From f34da1aff682828dfba78cd17034c0fc2cb40fda Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 1 Nov 2024 14:03:13 +0000 Subject: [PATCH 23/42] Make new index page with hover CSS. --- doc/_static/css/custom.css | 20 +++ doc/conf.py | 6 +- doc/index.rst | 1 + doc/tutorials_custom_index.rst | 254 +++++++++++++++++++++++++++++++++ 4 files changed, 279 insertions(+), 2 deletions(-) create mode 100644 doc/_static/css/custom.css create mode 100644 doc/tutorials_custom_index.rst diff --git a/doc/_static/css/custom.css b/doc/_static/css/custom.css new file mode 100644 index 0000000000..0c51da539e --- /dev/null +++ b/doc/_static/css/custom.css @@ -0,0 +1,20 @@ +/* Center and make the title bold */ +.gallery-card .grid-item-card-title { + text-align: center; + font-weight: bold; +} + +/* Default style for hover content (hidden) */ +.gallery-card .hover-content { + display: none; + text-align: center; +} + +/* Show the hover content when hovering over the card */ +.gallery-card:hover .default-title { + display: none; +} + +.gallery-card:hover .hover-content { + display: block; +} diff --git a/doc/conf.py b/doc/conf.py index e3d58ca8f2..db16269991 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -109,8 +109,10 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ['_static'] - +html_static_path = ['_static'] +html_css_files = [ + 'css/custom.css', +] html_favicon = "images/favicon-32x32.png" diff --git a/doc/index.rst b/doc/index.rst index ed443e4200..57a0c95443 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -51,6 +51,7 @@ SpikeInterface is made of several modules to deal with different aspects of the overview get_started/index + tutorials_custom_index tutorials/index how_to/index modules/index diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst new file mode 100644 index 0000000000..46a7bea630 --- /dev/null +++ b/doc/tutorials_custom_index.rst @@ -0,0 +1,254 @@ +.. This page provides a custom index to the 'Tutorials' page, rather than the default sphinx-gallery +.. generated page. The benefits of this are flexibility in design and inclusion of non-sphinx files in the index. +.. +.. To update this index with a new documentation page +.. 1) Copy the grid-item-card and associated ".. raw:: html" section. +.. 2) change :link: to a link to your page. If this is an `.rst` file, point to the rst file directly. +.. If it is a sphinx-gallery generated file, format the path as separated by underscore and prefix `sphx_glr`, +.. pointing to the .py file. e.g. `tutorials/my/page.py` -> `sphx_glr_tutorials_my_page.py +.. 3) Change :img-top: to point to the thumbnail image of your choosing. You can point to images generated +.. in the sphinx gallery page if you wish. +.. 4) In the `html` section, change the `default-title` to your pages title and `hover-content` to the subtitle. + +:orphan: + +TutorialsNew +============ + +Longer form tutorials about using SpikeInterface. Many of these are downloadable as notebooks or Python scripts so that you can "code along" with the tutorials. + +If you're new to SpikeInterface, we recommend trying out the :ref:`get_started/quickstart:Quickstart tutorial` first. + +Updating from legacy +-------------------- + +.. toctree:: + :maxdepth: 1 + + tutorials/waveform_extractor_to_sorting_analyzer + +Core tutorials +-------------- + +These tutorials focus on the :py:mod:`spikeinterface.core` module. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_1_recording_extractor.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_1_recording_extractor_thumb.png + :img-alt: Recording objects + :class-card: gallery-card + + .. raw:: html + +
Recording objects
+
Manage loaded recordings in SpikeInterface
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_2_sorting_extractor.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_2_sorting_extractor_thumb.png + :img-alt: Sorting objects + :class-card: gallery-card + + .. raw:: html + +
Sorting objects
+
Explore sorting extractor features
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_3_handle_probe_info.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_3_handle_probe_info_thumb.png + :img-alt: Handling probe information + :class-card: gallery-card + + .. raw:: html + +
Handling probe information
+
Handle and visualize probe information
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_4_sorting_analyzer.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_4_sorting_analyzer_thumb.png + :img-alt: SortingAnalyzer + :class-card: gallery-card + + .. raw:: html + +
SortingAnalyzer
+
Analyze sorting results with ease
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_5_append_concatenate_segments.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_5_append_concatenate_segments_thumb.png + :img-alt: Append/Concatenate segments + :class-card: gallery-card + + .. raw:: html + +
Append and/or concatenate segments
+
Combine segments efficiently
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_6_handle_times.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_6_handle_times_thumb.png + :img-alt: Handle time information + :class-card: gallery-card + + .. raw:: html + +
Handle time information
+
Manage and analyze time information
+ +Extractors tutorials +-------------------- + +The :py:mod:`spikeinterface.extractors` module is designed to load and save recorded and sorted data, and to handle probe information. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_extractors_plot_1_read_various_formats.py + :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_1_read_various_formats_thumb.png + :img-alt: Read various formats + :class-card: gallery-card + + .. raw:: html + +
Read various formats
+
Read different recording formats efficiently
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_extractors_plot_2_working_with_unscaled_traces.py + :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_2_working_with_unscaled_traces_thumb.png + :img-alt: Unscaled traces + :class-card: gallery-card + + .. raw:: html + +
Working with unscaled traces
+
Learn about managing unscaled traces
+ +Quality metrics tutorial +------------------------ + +The :code:`spikeinterface.qualitymetrics` module allows users to compute various quality metrics to assess the goodness of a spike sorting output. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_mertics.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_mertics_thumb.png + :img-alt: Quality Metrics + :class-card: gallery-card + + .. raw:: html + +
Quality Metrics
+
Evaluate sorting quality using metrics
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_qualitymetrics_plot_4_curation.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_4_curation_thumb.png + :img-alt: Curation Tutorial + :class-card: gallery-card + + .. raw:: html + +
Curation Tutorial
+
Learn how to curate spike sorting data
+ +Comparison tutorial +------------------- + +The :code:`spikeinterface.comparison` module allows you to compare sorter outputs or benchmark against ground truth. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_comparison_plot_5_comparison_sorter_weaknesses.py + :img-top: /tutorials/comparison/images/thumb/sphx_glr_plot_5_comparison_sorter_weaknesses_thumb.png + :img-alt: Sorter Comparison + :class-card: gallery-card + + .. raw:: html + +
Sorter Comparison
+
Compare sorter outputs and assess weaknesses
+ +Widgets tutorials +----------------- + +The :code:`widgets` module contains several plotting routines (widgets) for visualizing recordings, sorting data, probe layout, and more. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_1_rec_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_1_rec_gallery_thumb.png + :img-alt: Recording Widgets + :class-card: gallery-card + + .. raw:: html + +
RecordingExtractor Widgets
+
Visualize recordings with widgets
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_2_sort_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_2_sort_gallery_thumb.png + :img-alt: Sorting Widgets + :class-card: gallery-card + + .. raw:: html + +
SortingExtractor Widgets
+
Explore sorting data using widgets
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_3_waveforms_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_3_waveforms_gallery_thumb.png + :img-alt: Waveforms Widgets + :class-card: gallery-card + + .. raw:: html + +
Waveforms Widgets
+
Display waveforms using SpikeInterface
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_4_peaks_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_4_peaks_gallery_thumb.png + :img-alt: Peaks Widgets + :class-card: gallery-card + + .. raw:: html + +
Peaks Widgets
+
Visualize detected peaks
+ +Download All Examples +--------------------- + +- :download:`Download all examples in Python source code ` +- :download:`Download all examples in Jupyter notebooks ` From 7aa93490cca20916338629518800a1cbf976b8ff Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 1 Nov 2024 14:53:22 +0000 Subject: [PATCH 24/42] Remove CSS and update development docs. --- doc/_static/css/custom.css | 20 ----- doc/conf.py | 6 +- doc/development/development.rst | 19 +++++ doc/index.rst | 1 - doc/tutorials_custom_index.rst | 128 +++++++++----------------------- 5 files changed, 56 insertions(+), 118 deletions(-) delete mode 100644 doc/_static/css/custom.css diff --git a/doc/_static/css/custom.css b/doc/_static/css/custom.css deleted file mode 100644 index 0c51da539e..0000000000 --- a/doc/_static/css/custom.css +++ /dev/null @@ -1,20 +0,0 @@ -/* Center and make the title bold */ -.gallery-card .grid-item-card-title { - text-align: center; - font-weight: bold; -} - -/* Default style for hover content (hidden) */ -.gallery-card .hover-content { - display: none; - text-align: center; -} - -/* Show the hover content when hovering over the card */ -.gallery-card:hover .default-title { - display: none; -} - -.gallery-card:hover .hover-content { - display: block; -} diff --git a/doc/conf.py b/doc/conf.py index db16269991..e3d58ca8f2 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -109,10 +109,8 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] -html_css_files = [ - 'css/custom.css', -] +# html_static_path = ['_static'] + html_favicon = "images/favicon-32x32.png" diff --git a/doc/development/development.rst b/doc/development/development.rst index a91818a271..1638c41243 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -213,6 +213,25 @@ We use Sphinx to build the documentation. To build the documentation locally, yo This will build the documentation in the :code:`doc/_build/html` folder. You can open the :code:`index.html` file in your browser to see the documentation. +Adding new documentation +------------------------ + +Documentation can be added as a +`sphinx-gallery `_ +python file ('tutorials') +or a +`sphinx rst `_ +file (all other sections). + +To add a new tutorial, add your ``.py`` file to ``spikeinterface/examples``. +Then, update the ``spikeinterface/doc/tutorials_custom_index.rst`` file +to make a new card linking to the page and an optional image. See +``tutorials_custom_index.rst`` header for more information. + +For other sections, write your documentation in ``.rst`` format and add +the page to the appropriate ``index.rst`` file found in the relevant +folder (e.g. ``how_to/index.rst``). + How to run code coverage locally -------------------------------- To run code coverage locally, you can use the following command: diff --git a/doc/index.rst b/doc/index.rst index 57a0c95443..e6d8aa3fea 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -52,7 +52,6 @@ SpikeInterface is made of several modules to deal with different aspects of the overview get_started/index tutorials_custom_index - tutorials/index how_to/index modules/index api diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst index 46a7bea630..4c7625d811 100644 --- a/doc/tutorials_custom_index.rst +++ b/doc/tutorials_custom_index.rst @@ -12,12 +12,14 @@ :orphan: -TutorialsNew +Tutorials ============ -Longer form tutorials about using SpikeInterface. Many of these are downloadable as notebooks or Python scripts so that you can "code along" with the tutorials. +Longer form tutorials about using SpikeInterface. Many of these are downloadable +as notebooks or Python scripts so that you can "code along" with the tutorials. -If you're new to SpikeInterface, we recommend trying out the :ref:`get_started/quickstart:Quickstart tutorial` first. +If you're new to SpikeInterface, we recommend trying out the +:ref:`get_started/quickstart:Quickstart tutorial` first. Updating from legacy -------------------- @@ -35,77 +37,53 @@ These tutorials focus on the :py:mod:`spikeinterface.core` module. .. grid:: 1 2 2 3 :gutter: 2 - .. grid-item-card:: + .. grid-item-card:: Recording objects :link-type: ref :link: sphx_glr_tutorials_core_plot_1_recording_extractor.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_1_recording_extractor_thumb.png :img-alt: Recording objects :class-card: gallery-card + :text-align: center - .. raw:: html - -
Recording objects
-
Manage loaded recordings in SpikeInterface
- - .. grid-item-card:: + .. grid-item-card:: Sorting objects :link-type: ref :link: sphx_glr_tutorials_core_plot_2_sorting_extractor.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_2_sorting_extractor_thumb.png :img-alt: Sorting objects :class-card: gallery-card + :text-align: center - .. raw:: html - -
Sorting objects
-
Explore sorting extractor features
- - .. grid-item-card:: + .. grid-item-card:: Handling probe information :link-type: ref :link: sphx_glr_tutorials_core_plot_3_handle_probe_info.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_3_handle_probe_info_thumb.png :img-alt: Handling probe information :class-card: gallery-card + :text-align: center - .. raw:: html - -
Handling probe information
-
Handle and visualize probe information
- - .. grid-item-card:: + .. grid-item-card:: SortingAnalyzer :link-type: ref :link: sphx_glr_tutorials_core_plot_4_sorting_analyzer.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_4_sorting_analyzer_thumb.png :img-alt: SortingAnalyzer :class-card: gallery-card + :text-align: center - .. raw:: html - -
SortingAnalyzer
-
Analyze sorting results with ease
- - .. grid-item-card:: + .. grid-item-card:: Append and/or concatenate segments :link-type: ref :link: sphx_glr_tutorials_core_plot_5_append_concatenate_segments.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_5_append_concatenate_segments_thumb.png :img-alt: Append/Concatenate segments :class-card: gallery-card + :text-align: center - .. raw:: html - -
Append and/or concatenate segments
-
Combine segments efficiently
- - .. grid-item-card:: + .. grid-item-card:: Handle time information :link-type: ref :link: sphx_glr_tutorials_core_plot_6_handle_times.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_6_handle_times_thumb.png :img-alt: Handle time information :class-card: gallery-card - - .. raw:: html - -
Handle time information
-
Manage and analyze time information
+ :text-align: center Extractors tutorials -------------------- @@ -115,29 +93,21 @@ The :py:mod:`spikeinterface.extractors` module is designed to load and save reco .. grid:: 1 2 2 3 :gutter: 2 - .. grid-item-card:: + .. grid-item-card:: Read various formats :link-type: ref :link: sphx_glr_tutorials_extractors_plot_1_read_various_formats.py :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_1_read_various_formats_thumb.png :img-alt: Read various formats :class-card: gallery-card + :text-align: center - .. raw:: html - -
Read various formats
-
Read different recording formats efficiently
- - .. grid-item-card:: + .. grid-item-card:: Working with unscaled traces :link-type: ref :link: sphx_glr_tutorials_extractors_plot_2_working_with_unscaled_traces.py :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_2_working_with_unscaled_traces_thumb.png :img-alt: Unscaled traces :class-card: gallery-card - - .. raw:: html - -
Working with unscaled traces
-
Learn about managing unscaled traces
+ :text-align: center Quality metrics tutorial ------------------------ @@ -147,29 +117,21 @@ The :code:`spikeinterface.qualitymetrics` module allows users to compute various .. grid:: 1 2 2 3 :gutter: 2 - .. grid-item-card:: + .. grid-item-card:: Quality Metrics :link-type: ref :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_mertics.py :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_mertics_thumb.png :img-alt: Quality Metrics :class-card: gallery-card + :text-align: center - .. raw:: html - -
Quality Metrics
-
Evaluate sorting quality using metrics
- - .. grid-item-card:: + .. grid-item-card:: Curation Tutorial :link-type: ref :link: sphx_glr_tutorials_qualitymetrics_plot_4_curation.py :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_4_curation_thumb.png :img-alt: Curation Tutorial :class-card: gallery-card - - .. raw:: html - -
Curation Tutorial
-
Learn how to curate spike sorting data
+ :text-align: center Comparison tutorial ------------------- @@ -179,17 +141,13 @@ The :code:`spikeinterface.comparison` module allows you to compare sorter output .. grid:: 1 2 2 3 :gutter: 2 - .. grid-item-card:: + .. grid-item-card:: Sorter Comparison :link-type: ref :link: sphx_glr_tutorials_comparison_plot_5_comparison_sorter_weaknesses.py :img-top: /tutorials/comparison/images/thumb/sphx_glr_plot_5_comparison_sorter_weaknesses_thumb.png :img-alt: Sorter Comparison :class-card: gallery-card - - .. raw:: html - -
Sorter Comparison
-
Compare sorter outputs and assess weaknesses
+ :text-align: center Widgets tutorials ----------------- @@ -199,53 +157,37 @@ The :code:`widgets` module contains several plotting routines (widgets) for visu .. grid:: 1 2 2 3 :gutter: 2 - .. grid-item-card:: + .. grid-item-card:: RecordingExtractor Widgets :link-type: ref :link: sphx_glr_tutorials_widgets_plot_1_rec_gallery.py :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_1_rec_gallery_thumb.png :img-alt: Recording Widgets :class-card: gallery-card + :text-align: center - .. raw:: html - -
RecordingExtractor Widgets
-
Visualize recordings with widgets
- - .. grid-item-card:: + .. grid-item-card:: SortingExtractor Widgets :link-type: ref :link: sphx_glr_tutorials_widgets_plot_2_sort_gallery.py :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_2_sort_gallery_thumb.png :img-alt: Sorting Widgets :class-card: gallery-card + :text-align: center - .. raw:: html - -
SortingExtractor Widgets
-
Explore sorting data using widgets
- - .. grid-item-card:: + .. grid-item-card:: Waveforms Widgets :link-type: ref :link: sphx_glr_tutorials_widgets_plot_3_waveforms_gallery.py :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_3_waveforms_gallery_thumb.png :img-alt: Waveforms Widgets :class-card: gallery-card + :text-align: center - .. raw:: html - -
Waveforms Widgets
-
Display waveforms using SpikeInterface
- - .. grid-item-card:: + .. grid-item-card:: Peaks Widgets :link-type: ref :link: sphx_glr_tutorials_widgets_plot_4_peaks_gallery.py :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_4_peaks_gallery_thumb.png :img-alt: Peaks Widgets :class-card: gallery-card - - .. raw:: html - -
Peaks Widgets
-
Visualize detected peaks
+ :text-align: center Download All Examples --------------------- From 0e44185c2918a2d7b53cfe55879fde134c478b57 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Nov 2024 11:05:26 +0100 Subject: [PATCH 25/42] Apply suggestions from code review --- src/spikeinterface/core/node_pipeline.py | 2 +- src/spikeinterface/sortingcomponents/peak_detection.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 8ca4ba7f3a..53c2445c77 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -544,7 +544,7 @@ def run_node_pipeline( recording_slices : None | list[tuple] Optionaly give a list of slices to run the pipeline only on some chunks of the recording. It must be a list of (segment_index, frame_start, frame_stop). - If None (default), the entire recording is computed. + If None (default), the function iterates over the entire duration of the recording. Returns ------- diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 233b16dcf7..d03744f8f9 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -87,7 +87,7 @@ def detect_peaks( recording_slices : None | list[tuple] Optionaly give a list of slices to run the pipeline only on some chunks of the recording. It must be a list of (segment_index, frame_start, frame_stop). - If None (default), the entire recording is computed. + If None (default), the function iterates over the entire duration of the recording. {method_doc} {job_doc} From 2faed131014fe46fd0fafddcb9b94872f889ca7c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Nov 2024 16:05:42 +0100 Subject: [PATCH 26/42] Remove the need of template_similarity extension for autocorrelogram plot --- src/spikeinterface/exporters/report.py | 4 ++-- src/spikeinterface/widgets/autocorrelograms.py | 8 +++++++- src/spikeinterface/widgets/crosscorrelograms.py | 3 ++- src/spikeinterface/widgets/unit_summary.py | 4 ++-- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 484da83342..ab08401382 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -22,8 +22,8 @@ def export_report( """ Exports a SI spike sorting report. The report includes summary figures of the spike sorting output. What is plotted depends on what has been calculated. Unit locations and unit waveforms are always included. - Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms` - and 'template_similarity', and `spike_amplitudes` have been computed for the given `sorting_analyzer`. + Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms`, + and `spike_amplitudes` have been computed for the given `sorting_analyzer`. Parameters ---------- diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index c8acd93dc2..c211a277f8 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -9,7 +9,13 @@ class AutoCorrelogramsWidget(CrossCorrelogramsWidget): # the doc is copied form CrossCorrelogramsWidget def __init__(self, *args, **kargs): - CrossCorrelogramsWidget.__init__(self, *args, **kargs) + _ = kargs.pop("min_similarity_for_correlograms", 0.0) + CrossCorrelogramsWidget.__init__( + self, + *args, + **kargs, + min_similarity_for_correlograms=None, + ) def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index cdb2041aa3..88dd803323 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -21,7 +21,8 @@ class CrossCorrelogramsWidget(BaseWidget): List of unit ids min_similarity_for_correlograms : float, default: 0.2 For sortingview backend. Threshold for computing pair-wise cross-correlograms. - If template similarity between two units is below this threshold, the cross-correlogram is not displayed + If template similarity between two units is below this threshold, the cross-correlogram is not displayed. + For auto-correlograms plot, this is automatically set to None. window_ms : float, default: 100.0 Window for CCGs in ms. If correlograms are already computed (e.g. with SortingAnalyzer), this argument is ignored diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index d8cbeb7bb3..9466110110 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -108,7 +108,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): fig = self.figure nrows = 2 ncols = 2 - if sorting_analyzer.has_extension("correlograms") and sorting_analyzer.has_extension("template_similarity"): + if sorting_analyzer.has_extension("correlograms"): ncols += 1 if sorting_analyzer.has_extension("waveforms"): ncols += 1 @@ -172,7 +172,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): col_counter += 1 ax_waveform_density.set_ylabel(None) - if sorting_analyzer.has_extension("correlograms") and sorting_analyzer.has_extension("template_similarity"): + if sorting_analyzer.has_extension("correlograms"): ax_correlograms = fig.add_subplot(gs[:2, col_counter]) AutoCorrelogramsWidget( sorting_analyzer, From 2ba37a8b3990af3919a3c1b294700909d144a457 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 5 Nov 2024 15:58:45 +0100 Subject: [PATCH 27/42] Don't let decimate mess with times and skim tests --- src/spikeinterface/preprocessing/decimate.py | 26 +++++++++---------- .../preprocessing/tests/test_decimate.py | 20 +++++++------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 334ebb02d2..c1b1cd9f80 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -63,18 +63,15 @@ def __init__( f"Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start/end frames." ) self._decimation_offset = decimation_offset - resample_rate = self._orig_samp_freq / self._decimation_factor + decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor - BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate) + BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency) - # in case there was a time_vector, it will be dropped for sanity. - # This is not necessary but consistent with ResampleRecording for parent_segment in recording._recording_segments: - parent_segment.time_vector = None self.add_recording_segment( DecimateRecordingSegment( parent_segment, - resample_rate, + decimated_sampling_frequency, self._orig_samp_freq, decimation_factor, decimation_offset, @@ -93,22 +90,25 @@ class DecimateRecordingSegment(BaseRecordingSegment): def __init__( self, parent_recording_segment, - resample_rate, + decimated_sampling_frequency, parent_rate, decimation_factor, decimation_offset, dtype, ): - if parent_recording_segment.t_start is None: - new_t_start = None + if parent_recording_segment.time_vector is not None: + time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor] + decimated_sampling_frequency = None else: - new_t_start = parent_recording_segment.t_start + decimation_offset / parent_rate + time_vector = None + if parent_recording_segment.t_start is None: + t_start = None + else: + t_start = parent_recording_segment.t_start + decimation_offset / parent_rate # Do not use BasePreprocessorSegment bcause we have to reset the sampling rate! BaseRecordingSegment.__init__( - self, - sampling_frequency=resample_rate, - t_start=new_t_start, + self, sampling_frequency=decimated_sampling_frequency, t_start=t_start, time_vector=time_vector ) self._parent_segment = parent_recording_segment self._decimation_factor = decimation_factor diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index 100972f762..adfcbd0d4a 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -8,19 +8,19 @@ import numpy as np -@pytest.mark.parametrize("N_segments", [1, 2]) -@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101]) -@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101]) +@pytest.mark.parametrize("num_segments", [1, 2]) +@pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101]) +@pytest.mark.parametrize("decimation_factor", [1, 7, 50]) @pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) @pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) -def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame): +def test_decimate(num_segments, decimation_offset, decimation_factor, start_frame, end_frame): rec = generate_recording() - segment_num_samps = [101 + i for i in range(N_segments)] + segment_num_samps = [101 + i for i in range(num_segments)] rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) - parent_traces = [rec.get_traces(i) for i in range(N_segments)] + parent_traces = [rec.get_traces(i) for i in range(num_segments)] if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor: with pytest.raises(ValueError): @@ -28,14 +28,14 @@ def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, return decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) - decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)] + decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] if start_frame is None: - start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) + start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) if end_frame is None: - end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) + end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) - for i in range(N_segments): + for i in range(num_segments): assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] assert np.all( decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame] From f90011803da2327d7ace74ff2a35b91b30c70d32 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 5 Nov 2024 16:35:27 +0100 Subject: [PATCH 28/42] More skimming and test decimate with times --- src/spikeinterface/preprocessing/decimate.py | 1 + .../preprocessing/tests/test_decimate.py | 57 +++++++++++++++---- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index c1b1cd9f80..2b47601fc2 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -99,6 +99,7 @@ def __init__( if parent_recording_segment.time_vector is not None: time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor] decimated_sampling_frequency = None + t_start = None else: time_vector = None if parent_recording_segment.t_start is None: diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index adfcbd0d4a..dd521cbe9b 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -11,13 +11,8 @@ @pytest.mark.parametrize("num_segments", [1, 2]) @pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101]) @pytest.mark.parametrize("decimation_factor", [1, 7, 50]) -@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) -@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) -def test_decimate(num_segments, decimation_offset, decimation_factor, start_frame, end_frame): - rec = generate_recording() - - segment_num_samps = [101 + i for i in range(num_segments)] - +def test_decimate(num_segments, decimation_offset, decimation_factor): + segment_num_samps = [20000, 40000] rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) parent_traces = [rec.get_traces(i) for i in range(num_segments)] @@ -30,10 +25,19 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] - if start_frame is None: - start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) - if end_frame is None: - end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + for start_frame in [0, 1, 5, None, 1000]: + for end_frame in [0, 1, 5, None, 1000]: + if start_frame is None: + start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + if end_frame is None: + end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + + for i in range(num_segments): + assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] + assert np.all( + decimated_rec.get_traces(i, start_frame, end_frame) + == decimated_parent_traces[i][start_frame:end_frame] + ) for i in range(num_segments): assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] @@ -42,5 +46,36 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram ) +def test_decimate_with_times(): + rec = generate_recording(durations=[5, 10]) + + # test with times + times = [rec.get_times(0) + 10, rec.get_times(1) + 20] + for i, t in enumerate(times): + rec.set_times(t, i) + + decimation_factor = 2 + decimation_offset = 1 + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + # test with t_start + rec = generate_recording(durations=[5, 10]) + t_starts = [10, 20] + for t_start, rec_segment in zip(t_starts, rec._recording_segments): + rec_segment.t_start = t_start + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + if __name__ == "__main__": test_decimate() From 2d843f8770a8587c32920d3af4dcc54bb8c05411 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 6 Nov 2024 10:09:56 +0100 Subject: [PATCH 29/42] Zach's comments --- src/spikeinterface/preprocessing/decimate.py | 2 +- src/spikeinterface/preprocessing/tests/test_decimate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 2b47601fc2..d5fc9d2025 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -105,7 +105,7 @@ def __init__( if parent_recording_segment.t_start is None: t_start = None else: - t_start = parent_recording_segment.t_start + decimation_offset / parent_rate + t_start = parent_recording_segment.t_start + (decimation_offset / parent_rate) # Do not use BasePreprocessorSegment bcause we have to reset the sampling rate! BaseRecordingSegment.__init__( diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index dd521cbe9b..aab17560a6 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("num_segments", [1, 2]) -@pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101]) +@pytest.mark.parametrize("decimation_offset", [0, 1, 5, 21, 101]) @pytest.mark.parametrize("decimation_factor", [1, 7, 50]) def test_decimate(num_segments, decimation_offset, decimation_factor): segment_num_samps = [20000, 40000] From d6b4c1e7474c372c6d9f71787ddbe707854bd11f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 7 Nov 2024 11:44:13 +0100 Subject: [PATCH 30/42] Fix cbin_file_path --- src/spikeinterface/extractors/cbin_ibl.py | 30 +++++++++++++++-------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index d7e5b58e11..88e1029ab0 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +import warnings import numpy as np import probeinterface @@ -30,8 +31,10 @@ class CompressedBinaryIblExtractor(BaseRecording): stream_name : {"ap", "lp"}, default: "ap". Whether to load AP or LFP band, one of "ap" or "lp". - cbin_file : str or None, default None + cbin_file_path : str or None, default None The cbin file of the recording. If None, searches in `folder_path` for file. + cbin_file : str or None, default None + (deprecated) The cbin file of the recording. If None, searches in `folder_path` for file. Returns ------- @@ -41,14 +44,21 @@ class CompressedBinaryIblExtractor(BaseRecording): installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" - def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file=None): + def __init__( + self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None, cbin_file=None + ): from neo.rawio.spikeglxrawio import read_meta_file try: import mtscomp except ImportError: raise ImportError(self.installation_mesg) - if cbin_file is None: + if cbin_file is not None: + warnings.warn( + "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", DeprecationWarning + ) + cbin_file_path = cbin_file + if cbin_file_path is None: folder_path = Path(folder_path) # check bands assert stream_name in ["ap", "lp"], "stream_name must be one of: 'ap', 'lp'" @@ -60,17 +70,17 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", assert ( len(curr_cbin_files) == 1 ), f"There should only be one `*.cbin` file in the folder, but {print(curr_cbin_files)} have been found" - cbin_file = curr_cbin_files[0] + cbin_file_path = curr_cbin_files[0] else: - cbin_file = Path(cbin_file) - folder_path = cbin_file.parent + cbin_file_path = Path(cbin_file_path) + folder_path = cbin_file_path.parent - ch_file = cbin_file.with_suffix(".ch") - meta_file = cbin_file.with_suffix(".meta") + ch_file = cbin_file_path.with_suffix(".ch") + meta_file = cbin_file_path.with_suffix(".meta") # reader cbuffer = mtscomp.Reader() - cbuffer.open(cbin_file, ch_file) + cbuffer.open(cbin_file_path, ch_file) # meta data meta = read_meta_file(meta_file) @@ -119,7 +129,7 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", self._kwargs = { "folder_path": str(Path(folder_path).resolve()), "load_sync_channel": load_sync_channel, - "cbin_file": str(Path(cbin_file).resolve()), + "cbin_file_path": str(Path(cbin_file_path).resolve()), } From e6f45056852e181fb8d6909c8a3365a08cb2c8f5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 7 Nov 2024 15:56:48 +0100 Subject: [PATCH 31/42] Update src/spikeinterface/extractors/cbin_ibl.py --- src/spikeinterface/extractors/cbin_ibl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 88e1029ab0..357afde04e 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -55,7 +55,7 @@ def __init__( raise ImportError(self.installation_mesg) if cbin_file is not None: warnings.warn( - "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", DeprecationWarning + "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", DeprecationWarning, stacklevel=2 ) cbin_file_path = cbin_file if cbin_file_path is None: From 471ce724faac7245766538880b7fcd196f49fa30 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:57:14 +0000 Subject: [PATCH 32/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/extractors/cbin_ibl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 357afde04e..8fe19f3d7e 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -55,7 +55,9 @@ def __init__( raise ImportError(self.installation_mesg) if cbin_file is not None: warnings.warn( - "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", DeprecationWarning, stacklevel=2 + "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", + DeprecationWarning, + stacklevel=2, ) cbin_file_path = cbin_file if cbin_file_path is None: From f293303bef50073cb71add06e94410635615384b Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 8 Nov 2024 16:06:30 +0100 Subject: [PATCH 33/42] Removing useless dependencies --- src/spikeinterface/postprocessing/template_metrics.py | 2 +- src/spikeinterface/postprocessing/template_similarity.py | 2 +- src/spikeinterface/postprocessing/unit_locations.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 306e9594b8..6e7bcf21b8 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -97,7 +97,7 @@ class ComputeTemplateMetrics(AnalyzerExtension): extension_name = "template_metrics" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index cfa9d89fea..6c30e2730b 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -44,7 +44,7 @@ class ComputeTemplateSimilarity(AnalyzerExtension): extension_name = "template_similarity" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False need_backward_compatibility_on_load = True diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 4029fc88c7..3f6dd47eec 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -39,7 +39,7 @@ class ComputeUnitLocations(AnalyzerExtension): extension_name = "unit_locations" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False need_backward_compatibility_on_load = True From 620f8013b8bf4f1332a7802dd3f6804ce068493c Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 13:37:50 +0000 Subject: [PATCH 34/42] Apply to all segments if 'segment_index' is 'None'. --- src/spikeinterface/core/baserecording.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b8a0420794..7392caa69b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -521,13 +521,20 @@ def shift_start_time(self, shift, segment_index=None): segment_index : int | None The segment on which to shift the times. """ - segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + self._check_segment_index(segment_index) - if self.has_time_vector(): - rs.time_vector += shift + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) else: - rs.t_start += shift + segments_to_shift = (segment_index,) + + for idx in segments_to_shift: + rs = self._recording_segments[idx] + + if self.has_time_vector(): + rs.time_vector += shift + else: + rs.t_start += shift def sample_index_to_time(self, sample_ind, segment_index=None): """ From 22d5dfc2a552e00d7b55d7c28681e25a1f51a711 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 13:39:34 +0000 Subject: [PATCH 35/42] Add type hints. --- src/spikeinterface/core/baserecording.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 7392caa69b..0af9c4bb6a 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -509,7 +509,7 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency - def shift_start_time(self, shift, segment_index=None): + def shift_start_time(self, shift: int | float, segment_index: int | None = None) -> None: """ Shift the starting time of the times. @@ -536,15 +536,14 @@ def shift_start_time(self, shift, segment_index=None): else: rs.t_start += shift - def sample_index_to_time(self, sample_ind, segment_index=None): - """ - Transform sample index into time in seconds - """ + def sample_index_to_time(self, sample_ind: int, segment_index: int | None = None): + """ """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.sample_index_to_time(sample_ind) - def time_to_sample_index(self, time_s, segment_index=None): + def time_to_sample_index(self, time_s: float, segment_index: int | None = None): + """ """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.time_to_sample_index(time_s) From 458a3dcc201380740583ef1f075951e83ee77ed8 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 13:43:45 +0000 Subject: [PATCH 36/42] Update name and docstring. --- src/spikeinterface/core/baserecording.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 0af9c4bb6a..91f99f17b0 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -509,19 +509,24 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency - def shift_start_time(self, shift: int | float, segment_index: int | None = None) -> None: + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: """ - Shift the starting time of the times. + Shift all times by a scalar value. The default behaviour is to + shift all segments uniformly. + Parameters + ---------- shift : int | float - The shift to apply to the first time point. If positive, - the current start time will be increased by `shift`. If - negative, the start time will be decreased. + The shift to apply. If positive, times will be increased by `shift`. + e.g. shifting by 1 will be like the recording started 1 second later. + If negative, the start time will be decreased i.e. as if the recording + started earlier. segment_index : int | None - The segment on which to shift the times. + The segment on which to shift the times. if `None`, all + segments will be shifted. """ - self._check_segment_index(segment_index) + self._check_segment_index(segment_index) # Check the segment index is valid only if segment_index is None: segments_to_shift = range(self.get_num_segments()) From 8845d3d7eb6caad8c6a5f0c12842f480766d3a26 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:40:16 +0000 Subject: [PATCH 37/42] Add verbose kwarg to mda write_recording --- src/spikeinterface/extractors/mdaextractors.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index f055e1d7c9..d2886d9e79 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -72,6 +72,7 @@ def write_recording( params_fname="params.json", geom_fname="geom.csv", dtype=None, + verbose=False, **job_kwargs, ): """Write a recording to file in MDA format. @@ -93,6 +94,8 @@ def write_recording( File name of geom file dtype : dtype or None, default: None Data type to be used. If None dtype is same as recording traces. + verbose : bool + If True, shows progress bar when saving recording. **job_kwargs: Use by job_tools modules to set: @@ -130,6 +133,7 @@ def write_recording( dtype=dtype, byte_offset=header_size, add_file_extension=False, + verbose=verbose, **job_kwargs, ) From 3e98c670a27671590613b7c1c4118780a8c47ce8 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:32:48 +0000 Subject: [PATCH 38/42] Add tests. --- .../core/tests/test_time_handling.py | 92 ++++++++++++++++++- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index a129316ee7..9b7ed11bbb 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -15,7 +15,10 @@ class TestTimeHandling: is generated on the fly. Both time representations are tested here. """ - # Fixtures ##### + # ######################################################################### + # Fixtures + # ######################################################################### + @pytest.fixture(scope="session") def time_vector_recording(self): """ @@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name): raw_recording, times_recording, all_times = time_recording_fixture return (raw_recording, times_recording, all_times) - # Tests ##### + # ######################################################################### + # Tests + # ######################################################################### + def test_has_time_vector(self, time_vector_recording): """ Test the `has_time_vector` function returns `False` before @@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration()) - # Helpers #### + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_time_all_segments(self, request, fixture_name, shift): + """ + Shift the times in every segment using the `None` default, then + check that every segment of the recording is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + times_recording.shift_times(shift) # use default `segment_index=None` + + for idx in range(num_segments): + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift, rtol=0, atol=1e-8 + ) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_times_different_segments(self, request, fixture_name, shift): + """ + Shift each segment separately, and check the shifted segment only + is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + # For each segment, shift the segment only and check the + # times are updated as expected. + for idx in range(num_segments): + + scaler = idx + 2 + times_recording.shift_times(shift * scaler, segment_index=idx) + + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift * scaler, rtol=0, atol=1e-8 + ) + + # Just do a little check that we are not + # accidentally changing some other segments, + # which should remain unchanged at this point in the loop. + if idx != num_segments - 1: + assert np.array_equal(orig_seg_data[idx + 1], times_recording.get_times(segment_index=idx + 1)) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_save_and_load_time_shift(self, request, fixture_name, tmp_path): + """ + Save the shifted data and check the shift is propagated correctly. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + shift = 100 + times_recording.shift_times(shift=shift) + + times_recording.save(folder=tmp_path / "my_file") + + loaded_recording = si.load_extractor(tmp_path / "my_file") + + for idx in range(times_recording.get_num_segments()): + assert np.array_equal( + times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx) + ) + + def _store_all_times(self, recording): + """ + Convenience function to store original times of all segments to a dict. + """ + num_segments = recording.get_num_segments() + seg_data = {} + + for idx in range(num_segments): + seg_data[idx] = copy.deepcopy(recording.get_times(segment_index=idx)) + + return num_segments, seg_data + + # ######################################################################### + # Helpers + # ######################################################################### + def _check_times_match(self, recording, all_times): """ For every segment in a recording, check the `get_times()` From 4d7246a529e3d17747cf5a496a0a04bd97f4eb09 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:33:17 +0000 Subject: [PATCH 39/42] Fixes on shift function. --- src/spikeinterface/core/baserecording.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 91f99f17b0..4b545dc7c7 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -526,8 +526,6 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N The segment on which to shift the times. if `None`, all segments will be shifted. """ - self._check_segment_index(segment_index) # Check the segment index is valid only - if segment_index is None: segments_to_shift = range(self.get_num_segments()) else: @@ -536,7 +534,7 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N for idx in segments_to_shift: rs = self._recording_segments[idx] - if self.has_time_vector(): + if self.has_time_vector(segment_index=idx): rs.time_vector += shift else: rs.t_start += shift From a1cf3367d18a549281208b25c622f2a1ee773226 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:35:32 +0000 Subject: [PATCH 40/42] Undo out of scope changes. --- src/spikeinterface/core/baserecording.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 4b545dc7c7..886f7db79f 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -539,14 +539,15 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N else: rs.t_start += shift - def sample_index_to_time(self, sample_ind: int, segment_index: int | None = None): - """ """ + def sample_index_to_time(self, sample_ind, segment_index=None): + """ + Transform sample index into time in seconds + """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.sample_index_to_time(sample_ind) - def time_to_sample_index(self, time_s: float, segment_index: int | None = None): - """ """ + def time_to_sample_index(self, time_s, segment_index=None): segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.time_to_sample_index(time_s) From 469b3b0e36fdbc0571d37e100d99d6c741af1377 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:37:20 +0000 Subject: [PATCH 41/42] Fix docstring. --- src/spikeinterface/core/baserecording.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 886f7db79f..6d9d2a827f 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -511,8 +511,7 @@ def reset_times(self): def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: """ - Shift all times by a scalar value. The default behaviour is to - shift all segments uniformly. + Shift all times by a scalar value. Parameters ---------- @@ -523,8 +522,8 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N started earlier. segment_index : int | None - The segment on which to shift the times. if `None`, all - segments will be shifted. + The segment on which to shift the times. + If `None`, all segments will be shifted. """ if segment_index is None: segments_to_shift = range(self.get_num_segments()) From 1e53a5e06b2a90956d72150826a8f590d673b5ce Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 19 Nov 2024 17:45:40 +0100 Subject: [PATCH 42/42] Update src/spikeinterface/extractors/cbin_ibl.py Co-authored-by: Heberto Mayorquin --- src/spikeinterface/extractors/cbin_ibl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 8fe19f3d7e..728d352973 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -31,7 +31,7 @@ class CompressedBinaryIblExtractor(BaseRecording): stream_name : {"ap", "lp"}, default: "ap". Whether to load AP or LFP band, one of "ap" or "lp". - cbin_file_path : str or None, default None + cbin_file_path : str, Path or None, default None The cbin file of the recording. If None, searches in `folder_path` for file. cbin_file : str or None, default None (deprecated) The cbin file of the recording. If None, searches in `folder_path` for file.