diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 8aa1815a55..5df9a7e6b1 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -385,7 +385,7 @@ and merging unit groups. sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3]) sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0]) - sorting_analyzer_merge = sorting_analyzer.merge_units([0, 1], [2, 3]) + sorting_analyzer_merge = sorting_analyzer.merge_units([[0, 1], [2, 3]]) All computed extensions will be automatically propagated or merged when curating. Please refer to the :ref:`modules/curation:Curation module` documentation for more information. diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index d115b33e4a..37de992806 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -88,7 +88,7 @@ The ``censored_period_ms`` parameter is the time window in milliseconds to consi The :py:func:`~spikeinterface.curation.remove_redundand_units` function removes redundant units from the sorting output. Redundant units are units that share over a certain percentage of spikes, by default 80%. -The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. +The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. .. code-block:: python @@ -102,13 +102,18 @@ The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. ) # remove redundant units from SortingAnalyzer object - clean_sorting_analyzer = remove_redundant_units( + # note this returns a cleaned sorting + clean_sorting = remove_redundant_units( sorting_analyzer, duplicate_threshold=0.9, remove_strategy="min_shift" ) + # in order to have a SortingAnalyer with only the non-redundant units one must + # select the designed units remembering to give format and folder if one wants + # a persistent SortingAnalyzer. + clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids) -We recommend usinf the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps +We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps the unit (among the redundant ones), with a better template alignment. diff --git a/pyproject.toml b/pyproject.toml index a43ab63c8e..fc09ad9198 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.9,<4.0" +requires-python = ">=3.9,<3.13" # Only numpy 2.1 supported on python 3.13 for windows. We need to wait for fix on neo classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index bc5de63d07..447bbe562e 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -22,21 +22,23 @@ class ComputeRandomSpikes(AnalyzerExtension): """ - AnalyzerExtension that select some random spikes. + AnalyzerExtension that select somes random spikes. + This allows for a subsampling of spikes for further calculations and is important + for managing that amount of memory and speed of computation in the analyzer. This will be used by the `waveforms`/`templates` extensions. - This internally use `random_spikes_selection()` parameters are the same. + This internally uses `random_spikes_selection()` parameters. Parameters ---------- - method: "uniform" | "all", default: "uniform" + method : "uniform" | "all", default: "uniform" The method to select the spikes - max_spikes_per_unit: int, default: 500 + max_spikes_per_unit : int, default: 500 The maximum number of spikes per unit, ignored if method="all" - margin_size: int, default: None + margin_size : int, default: None A margin on each border of segments to avoid border spikes, ignored if method="all" - seed: int or None, default: None + seed : int or None, default: None A seed for the random generator, ignored if method="all" Returns @@ -104,7 +106,7 @@ def get_random_spikes(self): return self._some_spikes def get_selected_indices_in_spike_train(self, unit_id, segment_index): - # usefull for Waveforms extractor backwars compatibility + # useful for WaveformExtractor backwards compatibility # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain sorting = self.sorting_analyzer.sorting random_spikes_indices = self.data["random_spikes_indices"] @@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension): Parameters ---------- - ms_before: float, default: 1.0 + ms_before : float, default: 1.0 The number of ms to extract before the spike events - ms_after: float, default: 2.0 + ms_after : float, default: 2.0 The number of ms to extract after the spike events - dtype: None | dtype, default: None + dtype : None | dtype, default: None The dtype of the waveforms. If None, the dtype of the recording is used. Returns ------- - waveforms: np.ndarray + waveforms : np.ndarray Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels) """ @@ -380,7 +382,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N assert isinstance(operators, list) for operator in operators: if isinstance(operator, str): - assert operator in ("average", "std", "median", "mad") + if operator not in ("average", "std", "median", "mad"): + error_msg = ( + f"You have entered an operator {operator} in your `operators` argument which is " + f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead." + ) + raise ValueError(error_msg) else: assert isinstance(operator, (list, tuple)) assert len(operator) == 2 @@ -405,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs): self._compute_and_append_from_waveforms(self.params["operators"]) else: - for operator in self.params["operators"]: - if operator not in ("average", "std"): - raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension") + bad_operator_list = [ + operator for operator in self.params["operators"] if operator not in ("average", "std") + ] + if len(bad_operator_list) > 0: + raise ValueError( + f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension" + ) recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting @@ -441,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs): def _compute_and_append_from_waveforms(self, operators): if not self.sorting_analyzer.has_extension("waveforms"): - raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension") + raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension") unit_ids = self.sorting_analyzer.unit_ids channel_ids = self.sorting_analyzer.channel_ids @@ -466,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators): assert self.sorting_analyzer.has_extension( "random_spikes" - ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()" + ), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')" some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() for unit_index, unit_id in enumerate(unit_ids): spike_mask = some_spikes["unit_index"] == unit_index @@ -549,9 +560,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): if operator != "percentile": key = operator else: - assert percentile is not None, "You must provide percentile=..." + assert percentile is not None, "You must provide percentile=... if `operator=percentile`" key = f"percentile_{percentile}" + if key not in self.data.keys(): + error_msg = ( + f"You have entered `operator={key}`, but the only operators calculated are " + f"{list(self.data.keys())}. Please use one of these as your `operator` in the " + f"`get_data` function." + ) + raise ValueError(error_msg) + templates_array = self.data[key] if outputs == "numpy": @@ -566,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): probe=self.sorting_analyzer.get_probe(), ) else: - raise ValueError("outputs must be numpy or Templates") + raise ValueError("outputs must be `numpy` or `Templates`") def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"): """ @@ -576,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save Parameters ---------- - unit_ids: list or None + unit_ids : list or None Unit ids to retrieve waveforms for - operator: "average" | "median" | "std" | "percentile", default: "average" + operator : "average" | "median" | "std" | "percentile", default: "average" The operator to compute the templates - percentile: float, default: None + percentile : float, default: None Percentile to use for operator="percentile" - save: bool, default True + save : bool, default: True In case, the operator is not computed yet it can be saved to folder or zarr - outputs: "numpy" | "Templates" + outputs : "numpy" | "Templates", default: "numpy" Whether to return a numpy array or a Templates object Returns ------- - templates: np.array + templates : np.array | Templates The returned templates (num_units, num_samples, num_channels) """ if operator != "percentile": key = operator else: - assert percentile is not None, "You must provide percentile=..." + assert percentile is not None, "You must provide percentile=... if `operator='percentile'`" key = f"pencentile_{percentile}" if key in self.data: @@ -632,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save is_scaled=self.sorting_analyzer.return_scaled, ) else: - raise ValueError("outputs must be numpy or Templates") + raise ValueError("`outputs` must be 'numpy' or 'Templates'") def get_unit_template(self, unit_id, operator="average"): """ @@ -642,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"): ---------- unit_id: str | int Unit id to retrieve waveforms for - operator: str + operator: str, default: "average" The operator to compute the templates Returns @@ -691,22 +710,23 @@ class ComputeNoiseLevels(AnalyzerExtension): need_recording = True use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None): - params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed) + def _set_params(self, **noise_level_params): + params = noise_level_params.copy() return params def _select_extension_data(self, unit_ids): - # this do not depend on units + # this does not depend on units return self.data def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): - # this do not depend on units + # this does not depend on units return self.data.copy() def _run(self, verbose=False): @@ -717,6 +737,15 @@ def _run(self, verbose=False): def _get_data(self): return self.data["noise_levels"] + def _handle_backward_compatibility_on_load(self): + # The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None) + # now it is handle more explicitly using random_slices_kwargs=dict() + for key in ("num_chunks_per_segment", "chunk_size", "seed"): + if key in self.params: + if "random_slices_kwargs" not in self.params: + self.params["random_slices_kwargs"] = dict() + self.params["random_slices_kwargs"][key] = self.params.pop(key) + register_result_extension(ComputeNoiseLevels) compute_noise_levels = ComputeNoiseLevels.function_factory() diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 310533c96b..2ec3664a45 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -172,8 +172,10 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): error_msg = ( - f"The given Probe have 'device_channel_indices' that do not match channel count \n" - f"{number_of_device_channel_indices} vs {self.get_num_channels()} \n" + f"The given Probe either has 'device_channel_indices' that does not match channel count \n" + f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" + f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" + f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" f"device_channel_indices are the following: {device_channel_indices} \n" f"recording channels are the following: {self.get_channel_ids()} \n" ) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 5240edcee7..7a6172369b 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size): def divide_recording_into_chunks(recording, chunk_size): - all_chunks = [] + recording_slices = [] for segment_index in range(recording.get_num_segments()): num_frames = recording.get_num_samples(segment_index) chunks = divide_segment_into_chunks(num_frames, chunk_size) - all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) - return all_chunks + recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + return recording_slices def ensure_n_jobs(recording, n_jobs=1): @@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1): return n_jobs +def chunk_duration_to_chunk_size(chunk_duration, recording): + if isinstance(chunk_duration, float): + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + elif isinstance(chunk_duration, str): + if chunk_duration.endswith("ms"): + chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 + elif chunk_duration.endswith("s"): + chunk_duration = float(chunk_duration.replace("s", "")) + else: + raise ValueError("chunk_duration must ends with s or ms") + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + else: + raise ValueError("chunk_duration must be str or float") + return chunk_size + + def ensure_chunk_size( recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs ): @@ -231,18 +247,7 @@ def ensure_chunk_size( num_channels = recording.get_num_channels() chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs)) elif chunk_duration is not None: - if isinstance(chunk_duration, float): - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - elif isinstance(chunk_duration, str): - if chunk_duration.endswith("ms"): - chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 - elif chunk_duration.endswith("s"): - chunk_duration = float(chunk_duration.replace("s", "")) - else: - raise ValueError("chunk_duration must ends with s or ms") - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - else: - raise ValueError("chunk_duration must be str or float") + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) else: # Edge case to define single chunk per segment for n_jobs=1. # All chunking parameters equal None mean single chunk per segment @@ -382,11 +387,13 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self): + def run(self, recording_slices=None): """ Runs the defined jobs. """ - all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) + + if recording_slices is None: + recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size) if self.handle_returns: returns = [] @@ -395,17 +402,17 @@ def run(self): if self.n_jobs == 1: if self.progress_bar: - all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name) + recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name) worker_ctx = self.init_func(*self.init_args) - for segment_index, frame_start, frame_stop in all_chunks: + for segment_index, frame_start, frame_stop in recording_slices: res = self.func(segment_index, frame_start, frame_stop, worker_ctx) if self.handle_returns: returns.append(res) if self.gather_func is not None: self.gather_func(res) else: - n_jobs = min(self.n_jobs, len(all_chunks)) + n_jobs = min(self.n_jobs, len(recording_slices)) # parallel with ProcessPoolExecutor( @@ -414,10 +421,10 @@ def run(self): mp_context=mp.get_context(self.mp_context), initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), ) as executor: - results = executor.map(function_wrapper, all_chunks) + results = executor.map(function_wrapper, recording_slices) if self.progress_bar: - results = tqdm(results, desc=self.job_name, total=len(all_chunks)) + results = tqdm(results, desc=self.job_name, total=len(recording_slices)) for res in results: if self.handle_returns: diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index d90a20902d..53c2445c77 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -489,6 +489,7 @@ def run_node_pipeline( names=None, verbose=False, skip_after_n_peaks=None, + recording_slices=None, ): """ Machinery to compute in parallel operations on peaks and traces. @@ -540,6 +541,10 @@ def run_node_pipeline( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. + recording_slices : None | list[tuple] + Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + It must be a list of (segment_index, frame_start, frame_stop). + If None (default), the function iterates over the entire duration of the recording. Returns ------- @@ -578,7 +583,7 @@ def run_node_pipeline( **job_kwargs, ) - processor.run() + processor.run(recording_slices=recording_slices) outs = gather_func.finalize_buffers(squeeze_output=squeeze_output) return outs diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 77d427bc88..4aabbfd587 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -18,6 +18,8 @@ fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc, + chunk_duration_to_chunk_size, + split_job_kwargs, ) @@ -512,33 +514,38 @@ def determine_cast_unsigned(recording, dtype): return cast_unsigned -def get_random_data_chunks( +def get_random_recording_slices( recording, - return_scaled=False, + method="full_random", num_chunks_per_segment=20, - chunk_size=10000, - concatenated=True, - seed=0, + chunk_duration="500ms", + chunk_size=None, margin_frames=0, + seed=None, ): """ - Extract random chunks across segments + Get random slice of a recording across segments. - This is used for instance in get_noise_levels() to estimate noise on traces. + This is used for instance in get_noise_levels() and get_random_data_chunks() to estimate noise on traces. Parameters ---------- recording : BaseRecording The recording to get random chunks from - return_scaled : bool, default: False - If True, returned chunks are scaled to uV + method : "full_random" + The method used to get random slices. + * "full_random" : legacy method, used until version 0.101.0, there is no constrain on slices + and they can overlap. num_chunks_per_segment : int, default: 20 Number of chunks per segment - chunk_size : int, default: 10000 - Size of a chunk in number of frames + chunk_duration : str | float | None, default "500ms" + The duration of each chunk in 's' or 'ms' + chunk_size : int | None + Size of a chunk in number of frames. This is ued only if chunk_duration is None. + This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead. concatenated : bool, default: True If True chunk are concatenated along time axis - seed : int, default: 0 + seed : int, default: None Random seed margin_frames : int, default: 0 Margin in number of frames to avoid edge effects @@ -547,42 +554,89 @@ def get_random_data_chunks( ------- chunk_list : np.array Array of concatenate chunks per segment + + """ # TODO: if segment have differents length make another sampling that dependant on the length of the segment # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY # And randomize the number of chunk per segment weighted by segment duration - # check chunk size - num_segments = recording.get_num_segments() - for segment_index in range(num_segments): - chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - if chunk_size > chunk_size_limit: - chunk_size = chunk_size_limit - 1 - warnings.warn( - f"chunk_size is greater than the number " - f"of samples for segment index {segment_index}. " - f"Using {chunk_size}." - ) + if method == "full_random": + if chunk_size is None: + if chunk_duration is not None: + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) + else: + raise ValueError("get_random_recording_slices need chunk_size or chunk_duration") + + # check chunk size + num_segments = recording.get_num_segments() + for segment_index in range(num_segments): + chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames + if chunk_size > chunk_size_limit: + chunk_size = chunk_size_limit - 1 + warnings.warn( + f"chunk_size is greater than the number " + f"of samples for segment index {segment_index}. " + f"Using {chunk_size}." + ) + rng = np.random.default_rng(seed) + recording_slices = [] + low = margin_frames + size = num_chunks_per_segment + for segment_index in range(num_segments): + num_frames = recording.get_num_frames(segment_index) + high = num_frames - chunk_size - margin_frames + random_starts = rng.integers(low=low, high=high, size=size) + random_starts = np.sort(random_starts) + recording_slices += [ + (segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts + ] + else: + raise ValueError(f"get_random_recording_slices : wrong method {method}") - rng = np.random.default_rng(seed) - chunk_list = [] - low = margin_frames - size = num_chunks_per_segment - for segment_index in range(num_segments): - num_frames = recording.get_num_frames(segment_index) - high = num_frames - chunk_size - margin_frames - random_starts = rng.integers(low=low, high=high, size=size) - segment_trace_chunk = [ - recording.get_traces( - start_frame=start_frame, - end_frame=(start_frame + chunk_size), - segment_index=segment_index, - return_scaled=return_scaled, - ) - for start_frame in random_starts - ] + return recording_slices - chunk_list.extend(segment_trace_chunk) + +def get_random_data_chunks(recording, return_scaled=False, concatenated=True, **random_slices_kwargs): + """ + Extract random chunks across segments. + + Internally, it uses `get_random_recording_slices()` and retrieves the traces chunk as a list + or a concatenated unique array. + + Please read `get_random_recording_slices()` for more details on parameters. + + + Parameters + ---------- + recording : BaseRecording + The recording to get random chunks from + return_scaled : bool, default: False + If True, returned chunks are scaled to uV + num_chunks_per_segment : int, default: 20 + Number of chunks per segment + concatenated : bool, default: True + If True chunk are concatenated along time axis + **random_slices_kwargs : dict + Options transmited to get_random_recording_slices(), please read documentation from this + function for more details. + + Returns + ------- + chunk_list : np.array | list of np.array + Array of concatenate chunks per segment + """ + recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + + chunk_list = [] + for segment_index, start_frame, end_frame in recording_slices: + traces_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + return_scaled=return_scaled, + ) + chunk_list.append(traces_chunk) if concatenated: return np.concatenate(chunk_list, axis=0) @@ -637,19 +691,52 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None): return np.array(closest_channels_inds), np.array(dists) +def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx): + recording = worker_ctx["recording"] + + one_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + return_scaled=worker_ctx["return_scaled"], + ) + + if worker_ctx["method"] == "mad": + med = np.median(one_chunk, axis=0, keepdims=True) + # hard-coded so that core doesn't depend on scipy + noise_levels = np.median(np.abs(one_chunk - med), axis=0) / 0.6744897501960817 + elif worker_ctx["method"] == "std": + noise_levels = np.std(one_chunk, axis=0) + + return noise_levels + + +def _noise_level_chunk_init(recording, return_scaled, method): + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["return_scaled"] = return_scaled + worker_ctx["method"] = method + return worker_ctx + + def get_noise_levels( recording: "BaseRecording", return_scaled: bool = True, method: Literal["mad", "std"] = "mad", force_recompute: bool = False, - **random_chunk_kwargs, + random_slices_kwargs: dict = {}, + **kwargs, ) -> np.ndarray: """ Estimate noise for each channel using MAD methods. You can use standard deviation with `method="std"` Internally it samples some chunk across segment. - And then, it use MAD estimator (more robust than STD) + And then, it uses the MAD estimator (more robust than STD) or the STD on each chunk. + Finally the average of all MAD/STD values is performed. + + The result is cached in a property of the recording, so that the next call on the same + recording will use the cached result unless `force_recompute=True`. Parameters ---------- @@ -662,8 +749,11 @@ def get_noise_levels( The method to use to estimate noise levels force_recompute : bool If True, noise levels are recomputed even if they are already stored in the recording extractor - random_chunk_kwargs : dict - Kwargs for get_random_data_chunks + random_slices_kwargs : dict + Options transmited to get_random_recording_slices(), please read documentation from this + function for more details. + + {} Returns ------- @@ -679,19 +769,56 @@ def get_noise_levels( if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) else: - random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs) - - if method == "mad": - med = np.median(random_chunks, axis=0, keepdims=True) - # hard-coded so that core doesn't depend on scipy - noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817 - elif method == "std": - noise_levels = np.std(random_chunks, axis=0) + # This is to keep backward compatibility + # lets keep for a while and remove this maybe in 0.103.0 + # chunk_size used to be in the signature and now is ambiguous + random_slices_kwargs_, job_kwargs = split_job_kwargs(kwargs) + if len(random_slices_kwargs_) > 0 or "chunk_size" in job_kwargs: + msg = ( + "get_noise_levels(recording, num_chunks_per_segment=20) is deprecated\n" + "Now, you need to use get_noise_levels(recording, random_slices_kwargs=dict(num_chunks_per_segment=20, chunk_size=1000))\n" + "Please read get_random_recording_slices() documentation for more options." + ) + # if the user use both the old and the new behavior then an error is raised + assert len(random_slices_kwargs) == 0, msg + warnings.warn(msg) + random_slices_kwargs = random_slices_kwargs_ + if "chunk_size" in job_kwargs: + random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] + + recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + + noise_levels_chunks = [] + + def append_noise_chunk(res): + noise_levels_chunks.append(res) + + func = _noise_level_chunk + init_func = _noise_level_chunk_init + init_args = (recording, return_scaled, method) + executor = ChunkRecordingExecutor( + recording, + func, + init_func, + init_args, + job_name="noise_level", + verbose=False, + gather_func=append_noise_chunk, + **job_kwargs, + ) + executor.run(recording_slices=recording_slices) + noise_levels_chunks = np.stack(noise_levels_chunks) + noise_levels = np.mean(noise_levels_chunks, axis=0) + + # set property recording.set_property(key, noise_levels) return noise_levels +get_noise_levels.__doc__ = get_noise_levels.__doc__.format(_shared_job_kwargs_doc) + + def get_chunk_with_margin( rec_segment, start_frame, diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 934b18ed49..3c8663df70 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -31,7 +31,12 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc ) ext = one_object.get_extension("templates") if ext is not None: - templates_array = ext.data["average"] + if "average" in ext.data: + templates_array = ext.data.get("average") + elif "median" in ext.data: + templates_array = ext.data.get("median") + else: + raise ValueError("Average or median templates have not been computed.") else: raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates") else: diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index 626899ab6e..6f5bef3c6c 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -2,6 +2,8 @@ import shutil +from pathlib import Path + from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer from spikeinterface.core import Templates @@ -250,16 +252,17 @@ def test_compute_several(create_cache_folder): if __name__ == "__main__": - - test_ComputeWaveforms(format="memory", sparse=True) - test_ComputeWaveforms(format="memory", sparse=False) - test_ComputeWaveforms(format="binary_folder", sparse=True) - test_ComputeWaveforms(format="binary_folder", sparse=False) - test_ComputeWaveforms(format="zarr", sparse=True) - test_ComputeWaveforms(format="zarr", sparse=False) - test_ComputeRandomSpikes(format="memory", sparse=True) - test_ComputeTemplates(format="memory", sparse=True) - test_ComputeNoiseLevels(format="memory", sparse=False) - - test_get_children_dependencies() - test_delete_on_recompute() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "core" + # test_ComputeWaveforms(format="memory", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="memory", sparse=False, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="binary_folder", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="binary_folder", sparse=False, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="zarr", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="zarr", sparse=False, create_cache_folder=cache_folder) + # test_ComputeRandomSpikes(format="memory", sparse=True, create_cache_folder=cache_folder) + test_ComputeRandomSpikes(format="binary_folder", sparse=False, create_cache_folder=cache_folder) + test_ComputeTemplates(format="memory", sparse=True, create_cache_folder=cache_folder) + test_ComputeNoiseLevels(format="memory", sparse=False, create_cache_folder=cache_folder) + + # test_get_children_dependencies() + # test_delete_on_recompute(cache_folder) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index deef2291c6..028eaecf12 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -4,7 +4,7 @@ import shutil from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording - +from spikeinterface.core.job_tools import divide_recording_into_chunks # from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ( @@ -191,8 +191,8 @@ def test_run_node_pipeline(cache_folder_creation): unpickled_node = pickle.loads(pickled_node) -def test_skip_after_n_peaks(): - recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) +def test_skip_after_n_peaks_and_recording_slices(): + recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0], seed=2205) # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) @@ -211,18 +211,27 @@ def test_skip_after_n_peaks(): node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True) nodes = [node0, node1] + # skip skip_after_n_peaks = 30 some_amplitudes = run_node_pipeline( recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks ) - assert some_amplitudes.size >= skip_after_n_peaks assert some_amplitudes.size < spikes.size + # slices : 1 every 4 + recording_slices = divide_recording_into_chunks(recording, 10_000) + recording_slices = recording_slices[::4] + some_amplitudes = run_node_pipeline( + recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices + ) + tolerance = 1.2 + assert some_amplitudes.size < (spikes.size // 4) * tolerance + # the following is for testing locally with python or ipython. It is not used in ci or with pytest. if __name__ == "__main__": # folder = Path("./cache_folder/core") # test_run_node_pipeline(folder) - test_skip_after_n_peaks() + test_skip_after_n_peaks_and_recording_slices() diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 23a1574f2a..dad5273f12 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -11,6 +11,7 @@ from spikeinterface.core.recording_tools import ( write_binary_recording, write_memory_recording, + get_random_recording_slices, get_random_data_chunks, get_chunk_with_margin, get_closest_channels, @@ -167,6 +168,17 @@ def test_write_memory_recording(): shm.unlink() +def test_get_random_recording_slices(): + rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) + rec_slices = get_random_recording_slices( + rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 + ) + assert len(rec_slices) == 40 + for seg_ind, start, stop in rec_slices: + assert stop - start == 500 + assert seg_ind in (0, 1) + + def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) @@ -182,16 +194,17 @@ def test_get_closest_channels(): def test_get_noise_levels(): + job_kwargs = dict(n_jobs=1, progress_bar=True) rec = generate_recording(num_channels=2, sampling_frequency=1000.0, durations=[60.0]) - noise_levels_1 = get_noise_levels(rec, return_scaled=False) - noise_levels_2 = get_noise_levels(rec, return_scaled=False) + noise_levels_1 = get_noise_levels(rec, return_scaled=False, **job_kwargs) + noise_levels_2 = get_noise_levels(rec, return_scaled=False, **job_kwargs) rec.set_channel_gains(0.1) rec.set_channel_offsets(0) - noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True) + noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True, **job_kwargs) - noise_levels = get_noise_levels(rec, return_scaled=True, method="std") + noise_levels = get_noise_levels(rec, return_scaled=True, method="std", **job_kwargs) # Generate a recording following a gaussian distribution to check the result of get_noise. std = 6.0 @@ -201,8 +214,10 @@ def test_get_noise_levels(): recording = NumpyRecording(traces, 30000) assert np.all(noise_levels_1 == noise_levels_2) - assert np.allclose(get_noise_levels(recording, return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) - assert np.allclose(get_noise_levels(recording, method="std", return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose(get_noise_levels(recording, return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose( + get_noise_levels(recording, method="std", return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3 + ) def test_get_noise_levels_output(): @@ -216,10 +231,21 @@ def test_get_noise_levels_output(): traces = rng.normal(loc=10.0, scale=std, size=(num_samples, num_channels)) recording = NumpyRecording(traces_list=traces, sampling_frequency=sampling_frequency) - std_estimated_with_mad = get_noise_levels(recording, method="mad", return_scaled=False, chunk_size=1_000) + std_estimated_with_mad = get_noise_levels( + recording, + method="mad", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) + print(std_estimated_with_mad) assert np.allclose(std_estimated_with_mad, [std, std], rtol=1e-2, atol=1e-3) - std_estimated_with_std = get_noise_levels(recording, method="std", return_scaled=False, chunk_size=1_000) + std_estimated_with_std = get_noise_levels( + recording, + method="std", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) assert np.allclose(std_estimated_with_std, [std, std], rtol=1e-2, atol=1e-3) @@ -333,14 +359,16 @@ def test_do_recording_attributes_match(): if __name__ == "__main__": # Create a temporary folder using the standard library - import tempfile - - with tempfile.TemporaryDirectory() as tmpdirname: - tmp_path = Path(tmpdirname) - test_write_binary_recording(tmp_path) - test_write_memory_recording() - - test_get_random_data_chunks() - test_get_closest_channels() - test_get_noise_levels() - test_order_channels_by_depth() + # import tempfile + + # with tempfile.TemporaryDirectory() as tmpdirname: + # tmp_path = Path(tmpdirname) + # test_write_binary_recording(tmp_path) + # test_write_memory_recording() + + test_get_random_recording_slices() + # test_get_random_data_chunks() + # test_get_closest_channels() + # test_get_noise_levels() + # test_get_noise_levels_output() + # test_order_channels_by_depth() diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 657b936fb9..0302ffe5b7 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -3,7 +3,7 @@ from .remove_redundant import remove_redundant_units, find_redundant_units from .remove_duplicated_spikes import remove_duplicated_spikes from .remove_excess_spikes import remove_excess_spikes -from .auto_merge import get_potential_auto_merge +from .auto_merge import compute_merge_unit_groups, auto_merge_units, get_potential_auto_merge # manual sorting, diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 19336e5943..4f4cff144e 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + from typing import Tuple import numpy as np import math @@ -12,49 +14,90 @@ HAVE_NUMBA = False from ..core import SortingAnalyzer, Templates -from ..core.template_tools import get_template_extremum_channel -from ..postprocessing import compute_correlograms from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates from .mergeunitssorting import MergeUnitsSorting from .curation_tools import resolve_merging_graph - -_possible_presets = ["similarity_correlograms", "x_contaminations", "temporal_splits", "feature_neighbors"] +_compute_merge_presets = { + "similarity_correlograms": [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "correlogram", + "quality_score", + ], + "temporal_splits": [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "presence_distance", + "quality_score", + ], + "x_contaminations": [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "cross_contamination", + "quality_score", + ], + "feature_neighbors": [ + "num_spikes", + "snr", + "remove_contaminated", + "unit_locations", + "knn", + "quality_score", + ], +} _required_extensions = { - "unit_locations": ["unit_locations"], + "unit_locations": ["templates", "unit_locations"], "correlogram": ["correlograms"], - "template_similarity": ["template_similarity"], - "knn": ["spike_locations", "spike_amplitudes"], + "snr": ["templates", "noise_levels"], + "template_similarity": ["templates", "template_similarity"], + "knn": ["templates", "spike_locations", "spike_amplitudes"], } -def get_potential_auto_merge( +_default_step_params = { + "num_spikes": {"min_spikes": 100}, + "snr": {"min_snr": 2}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "unit_locations": {"max_distance_um": 150}, + "correlogram": { + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.15, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + "template_similarity": {"template_diff_thresh": 0.25}, + "presence_distance": {"presence_distance_thresh": 100}, + "knn": {"k_nn": 10}, + "cross_contamination": { + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, +} + + +def compute_merge_unit_groups( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", - resolve_graph: bool = False, - min_spikes: int = 100, - min_snr: float = 2, - max_distance_um: float = 150.0, - corr_diff_thresh: float = 0.16, - template_diff_thresh: float = 0.25, - contamination_thresh: float = 0.2, - presence_distance_thresh: float = 100, - p_value: float = 0.2, - cc_thresh: float = 0.1, - censored_period_ms: float = 0.3, - refractory_period_ms: float = 1.0, - sigma_smooth_ms: float = 0.6, - adaptative_window_thresh: float = 0.5, - censor_correlograms_ms: float = 0.15, - firing_contamination_balance: float = 2.5, - k_nn: int = 10, - knn_kwargs: dict | None = None, - presence_distance_kwargs: dict | None = None, + resolve_graph: bool = True, + steps_params: dict = None, + compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, -) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: + force_copy: bool = True, + **job_kwargs, +) -> list[tuple[int | str, int | str]] | Tuple[list[tuple[int | str, int | str]], dict]: """ Algorithm to find and check potential merges between units. @@ -78,6 +121,9 @@ def get_potential_auto_merge( Q = f(1 - (k + 1)C) + IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- @@ -98,47 +144,11 @@ def get_potential_auto_merge( * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | "knn", "quality_score" - If `preset` is None, you can specify the steps manually with the `steps` parameter. - resolve_graph : bool, default: False + resolve_graph : bool, default: True If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. - min_spikes : int, default: 100 - Minimum number of spikes for each unit to consider a potential merge. - Enough spikes are needed to estimate the correlogram - min_snr : float, default 2 - Minimum Signal to Noise ratio for templates to be considered while merging - max_distance_um : float, default: 150 - Maximum distance between units for considering a merge - corr_diff_thresh : float, default: 0.16 - The threshold on the "correlogram distance metric" for considering a merge. - It needs to be between 0 and 1 - template_diff_thresh : float, default: 0.25 - The threshold on the "template distance metric" for considering a merge. - It needs to be between 0 and 1 - contamination_thresh : float, default: 0.2 - Threshold for not taking in account a unit when it is too contaminated. - presence_distance_thresh : float, default: 100 - Parameter to control how present two units should be simultaneously. - p_value : float, default: 0.2 - The p-value threshold for the cross-contamination test. - cc_thresh : float, default: 0.1 - The threshold on the cross-contamination for considering a merge. - censored_period_ms : float, default: 0.3 - Used to compute the refractory period violations aka "contamination". - refractory_period_ms : float, default: 1 - Used to compute the refractory period violations aka "contamination". - sigma_smooth_ms : float, default: 0.6 - Parameters to smooth the correlogram estimation. - adaptative_window_thresh : float, default: 0.5 - Parameter to detect the window size in correlogram estimation. - censor_correlograms_ms : float, default: 0.15 - The period to censor on the auto and cross-correlograms. - firing_contamination_balance : float, default: 2.5 - Parameter to control the balance between firing rate and contamination in computing unit "quality score". - k_nn : int, default 5 - The number of neighbors to consider for every spike in the recording. - knn_kwargs : dict, default None - The dict of extra params to be passed to knn. + compute_needed_extensions : bool, default : True + Should we force the computation of needed extensions, if not already computed? extra_outputs : bool, default: False If True, an additional dictionary (`outs`) with processed data is returned. steps : None or list of str, default: None @@ -146,157 +156,141 @@ def get_potential_auto_merge( Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" Please check steps explanations above! - presence_distance_kwargs : None|dict, default: None - A dictionary of kwargs to be passed to compute_presence_distance(). + steps_params : dict + A dictionary whose keys are the steps, and keys are steps parameters. + force_copy : boolean, default: True + When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting + already computed extensions. False if you want to overwrite Returns ------- - potential_merges: - A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). - List of pairs that could be merged. + merge_unit_groups: + List of groups that need to be merge. + When `resolve_graph` is true (default) a list of tuples of 2+ elements + If `resolve_graph` is false then a list of tuple of 2 elements is returned instead. outs: Returned only when extra_outputs=True A dictionary that contains data for debugging and plotting. References ---------- - This function is inspired and built upon similar functions from Lussac [Llobet]_, + This function used to be inspired and built upon similar functions from Lussac [Llobet]_, done by Aurelien Wyngaard and Victor Llobet. https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + + However, it has been greatly consolidated and refined depending on the presets. """ import scipy sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids - # to get fast computation we will not analyse pairs when: - # * not enough spikes for one of theses - # * auto correlogram is contaminated - # * to far away one from each other - - all_steps = [ - "num_spikes", - "snr", - "remove_contaminated", - "unit_locations", - "correlogram", - "template_similarity", - "presence_distance", - "knn", - "cross_contamination", - "quality_score", - ] - - if preset is not None and preset not in _possible_presets: - raise ValueError(f"preset must be one of {_possible_presets}") - - if steps is None: - if preset is None: - if steps is None: - raise ValueError("You need to specify a preset or steps for the auto-merge function") - elif preset == "similarity_correlograms": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "correlogram", - "quality_score", - ] - elif preset == "temporal_splits": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "presence_distance", - "quality_score", - ] - elif preset == "x_contaminations": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "cross_contamination", - "quality_score", - ] - elif preset == "feature_neighbors": - steps = [ - "num_spikes", - "snr", - "remove_contaminated", - "unit_locations", - "knn", - "quality_score", - ] - + if preset is None and steps is None: + raise ValueError("You need to specify a preset or steps for the auto-merge function") + elif steps is not None: + # steps has precedence on presets + pass + elif preset is not None: + if preset not in _compute_merge_presets: + raise ValueError(f"preset must be one of {list(_compute_merge_presets.keys())}") + steps = _compute_merge_presets[preset] + + # check at least one extension is needed + at_least_one_extension_to_compute = False for step in steps: + assert step in _default_step_params, f"{step} is not a valid step" if step in _required_extensions: for ext in _required_extensions[step]: - if not sorting_analyzer.has_extension(ext): + if sorting_analyzer.has_extension(ext): + continue + if not compute_needed_extensions: raise ValueError(f"{step} requires {ext} extension") + at_least_one_extension_to_compute = True + + if force_copy and at_least_one_extension_to_compute: + # To avoid erasing the extensions of the user + sorting_analyzer = sorting_analyzer.copy() n = unit_ids.size - pair_mask = np.triu(np.arange(n)) > 0 + pair_mask = np.triu(np.arange(n), 1) > 0 outs = dict() for step in steps: - assert step in all_steps, f"{step} is not a valid step" + if step in _required_extensions: + for ext in _required_extensions[step]: + if sorting_analyzer.has_extension(ext): + continue + + # special case for templates + if ext == "templates" and not sorting_analyzer.has_extension("random_spikes"): + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + else: + sorting_analyzer.compute(ext, **job_kwargs) + + params = _default_step_params.get(step).copy() + if steps_params is not None and step in steps_params: + params.update(steps_params[step]) # STEP : remove units with too few spikes if step == "num_spikes": + num_spikes = sorting.count_num_spikes_per_unit(outputs="array") - to_remove = num_spikes < min_spikes + to_remove = num_spikes < params["min_spikes"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["num_spikes"] = to_remove # STEP : remove units with too small SNR elif step == "snr": qm_ext = sorting_analyzer.get_extension("quality_metrics") if qm_ext is None: - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + sorting_analyzer.compute("quality_metrics", metric_names=["snr"], **job_kwargs) qm_ext = sorting_analyzer.get_extension("quality_metrics") snrs = qm_ext.get_data()["snr"].values - to_remove = snrs < min_snr + to_remove = snrs < params["min_snr"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["snr"] = to_remove # STEP : remove contaminated auto corr elif step == "remove_contaminated": contaminations, nb_violations = compute_refrac_period_violations( - sorting_analyzer, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms + sorting_analyzer, + refractory_period_ms=params["refractory_period_ms"], + censored_period_ms=params["censored_period_ms"], ) nb_violations = np.array(list(nb_violations.values())) contaminations = np.array(list(contaminations.values())) - to_remove = contaminations > contamination_thresh + to_remove = contaminations > params["contamination_thresh"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["remove_contaminated"] = to_remove # STEP : unit positions are estimated roughly with channel - elif step == "unit_locations" in steps: + elif step == "unit_locations": location_ext = sorting_analyzer.get_extension("unit_locations") unit_locations = location_ext.get_data()[:, :2] unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - pair_mask = pair_mask & (unit_distances <= max_distance_um) + pair_mask = pair_mask & (unit_distances <= params["max_distance_um"]) outs["unit_distances"] = unit_distances # STEP : potential auto merge by correlogram - elif step == "correlogram" in steps: + elif step == "correlogram": correlograms_ext = sorting_analyzer.get_extension("correlograms") correlograms, bins = correlograms_ext.get_data() - mask = (bins[:-1] >= -censor_correlograms_ms) & (bins[:-1] < censor_correlograms_ms) + censor_ms = params["censor_correlograms_ms"] + sigma_smooth_ms = params["sigma_smooth_ms"] + mask = (bins[:-1] >= -censor_ms) & (bins[:-1] < censor_ms) correlograms[:, :, mask] = 0 correlograms_smoothed = smooth_correlogram(correlograms, bins, sigma_smooth_ms=sigma_smooth_ms) # find correlogram window for each units win_sizes = np.zeros(n, dtype=int) for unit_ind in range(n): auto_corr = correlograms_smoothed[unit_ind, unit_ind, :] - thresh = np.max(auto_corr) * adaptative_window_thresh + thresh = np.max(auto_corr) * params["adaptative_window_thresh"] win_size = get_unit_adaptive_window(auto_corr, thresh) win_sizes[unit_ind] = win_size correlogram_diff = compute_correlogram_diff( @@ -306,7 +300,7 @@ def get_potential_auto_merge( pair_mask=pair_mask, ) # print(correlogram_diff) - pair_mask = pair_mask & (correlogram_diff < corr_diff_thresh) + pair_mask = pair_mask & (correlogram_diff < params["corr_diff_thresh"]) outs["correlograms"] = correlograms outs["bins"] = bins outs["correlograms_smoothed"] = correlograms_smoothed @@ -314,22 +308,21 @@ def get_potential_auto_merge( outs["win_sizes"] = win_sizes # STEP : check if potential merge with CC also have template similarity - elif step == "template_similarity" in steps: + elif step == "template_similarity": template_similarity_ext = sorting_analyzer.get_extension("template_similarity") templates_similarity = template_similarity_ext.get_data() templates_diff = 1 - templates_similarity - pair_mask = pair_mask & (templates_diff < template_diff_thresh) + pair_mask = pair_mask & (templates_diff < params["template_diff_thresh"]) outs["templates_diff"] = templates_diff # STEP : check the vicinity of the spikes - elif step == "knn" in steps: - if knn_kwargs is None: - knn_kwargs = dict() - pair_mask = get_pairs_via_nntree(sorting_analyzer, k_nn, pair_mask, **knn_kwargs) + elif step == "knn": + pair_mask = get_pairs_via_nntree(sorting_analyzer, **params, pair_mask=pair_mask) # STEP : check how the rates overlap in times - elif step == "presence_distance" in steps: - presence_distance_kwargs = presence_distance_kwargs or dict() + elif step == "presence_distance": + presence_distance_kwargs = params.copy() + presence_distance_thresh = presence_distance_kwargs.pop("presence_distance_thresh") num_samples = [ sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting.get_num_segments()) ] @@ -340,40 +333,243 @@ def get_potential_auto_merge( outs["presence_distances"] = presence_distances # STEP : check if the cross contamination is significant - elif step == "cross_contamination" in steps: - refractory = (censored_period_ms, refractory_period_ms) + elif step == "cross_contamination": + refractory = ( + params["censored_period_ms"], + params["refractory_period_ms"], + ) CC, p_values = compute_cross_contaminations( - sorting_analyzer, pair_mask, cc_thresh, refractory, contaminations + sorting_analyzer, pair_mask, params["cc_thresh"], refractory, contaminations ) - pair_mask = pair_mask & (p_values > p_value) + pair_mask = pair_mask & (p_values > params["p_value"]) outs["cross_contaminations"] = CC, p_values # STEP : validate the potential merges with CC increase the contamination quality metrics - elif step == "quality_score" in steps: + elif step == "quality_score": pair_mask, pairs_decreased_score = check_improve_contaminations_score( sorting_analyzer, pair_mask, contaminations, - firing_contamination_balance, - refractory_period_ms, - censored_period_ms, + params["firing_contamination_balance"], + params["refractory_period_ms"], + params["censored_period_ms"], ) outs["pairs_decreased_score"] = pairs_decreased_score # FINAL STEP : create the final list from pair_mask boolean matrix ind1, ind2 = np.nonzero(pair_mask) - potential_merges = list(zip(unit_ids[ind1], unit_ids[ind2])) - - # some methods return identities ie (1,1) which we can cleanup first. - potential_merges = [(ids[0], ids[1]) for ids in potential_merges if ids[0] != ids[1]] + merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2])) if resolve_graph: - potential_merges = resolve_merging_graph(sorting, potential_merges) + merge_unit_groups = resolve_merging_graph(sorting, merge_unit_groups) if extra_outputs: - return potential_merges, outs + return merge_unit_groups, outs else: - return potential_merges + return merge_unit_groups + + +def auto_merge_units( + sorting_analyzer: SortingAnalyzer, compute_merge_kwargs: dict = {}, apply_merge_kwargs: dict = {}, **job_kwargs +) -> SortingAnalyzer: + """ + Compute merge unit groups and apply it on a SortingAnalyzer. + Internally uses `compute_merge_unit_groups()` + """ + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, extra_outputs=False, **compute_merge_kwargs, **job_kwargs + ) + + merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups, **apply_merge_kwargs, **job_kwargs) + return merged_analyzer + + +def get_potential_auto_merge( + sorting_analyzer: SortingAnalyzer, + preset: str | None = "similarity_correlograms", + resolve_graph: bool = False, + min_spikes: int = 100, + min_snr: float = 2, + max_distance_um: float = 150.0, + corr_diff_thresh: float = 0.16, + template_diff_thresh: float = 0.25, + contamination_thresh: float = 0.2, + presence_distance_thresh: float = 100, + p_value: float = 0.2, + cc_thresh: float = 0.1, + censored_period_ms: float = 0.3, + refractory_period_ms: float = 1.0, + sigma_smooth_ms: float = 0.6, + adaptative_window_thresh: float = 0.5, + censor_correlograms_ms: float = 0.15, + firing_contamination_balance: float = 1.5, + k_nn: int = 10, + knn_kwargs: dict | None = None, + presence_distance_kwargs: dict | None = None, + extra_outputs: bool = False, + steps: list[str] | None = None, +) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: + """ + This function is deprecated. Use compute_merge_unit_groups() instead. + This will be removed in 0.103.0 + + Algorithm to find and check potential merges between units. + + The merges are proposed based on a series of steps with different criteria: + + * "num_spikes": enough spikes are found in each unit for computing the correlogram (`min_spikes`) + * "snr": the SNR of the units is above a threshold (`min_snr`) + * "remove_contaminated": each unit is not contaminated (by checking auto-correlogram - `contamination_thresh`) + * "unit_locations": estimated unit locations are close enough (`max_distance_um`) + * "correlogram": the cross-correlograms of the two units are similar to each auto-corrleogram (`corr_diff_thresh`) + * "template_similarity": the templates of the two units are similar (`template_diff_thresh`) + * "presence_distance": the presence of the units is complementary in time (`presence_distance_thresh`) + * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`) + * "knn": the two units are close in the feature space + * "quality_score": the unit "quality score" is increased after the merge + + The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in + contamination (**C**), wheighted by a factor **k** (`firing_contamination_balance`). + + .. math:: + + Q = f(1 - (k + 1)C) + + IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer + preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | None, default: "similarity_correlograms" + The preset to use for the auto-merge. Presets combine different steps into a recipe and focus on: + + * | "similarity_correlograms": mainly focused on template similarity and correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "correlogram", "quality_score" + * | "x_contaminations": similar to "similarity_correlograms", but checks for cross-contamination instead of correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "cross_contamination", "quality_score" + * | "temporal_splits": focused on finding temporal splits using presence distance. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "presence_distance", "quality_score" + * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. + | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", + | "knn", "quality_score" + + If `preset` is None, you can specify the steps manually with the `steps` parameter. + resolve_graph : bool, default: False + If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. + min_spikes : int, default: 100 + Minimum number of spikes for each unit to consider a potential merge. + Enough spikes are needed to estimate the correlogram + min_snr : float, default 2 + Minimum Signal to Noise ratio for templates to be considered while merging + max_distance_um : float, default: 150 + Maximum distance between units for considering a merge + corr_diff_thresh : float, default: 0.16 + The threshold on the "correlogram distance metric" for considering a merge. + It needs to be between 0 and 1 + template_diff_thresh : float, default: 0.25 + The threshold on the "template distance metric" for considering a merge. + It needs to be between 0 and 1 + contamination_thresh : float, default: 0.2 + Threshold for not taking in account a unit when it is too contaminated. + presence_distance_thresh : float, default: 100 + Parameter to control how present two units should be simultaneously. + p_value : float, default: 0.2 + The p-value threshold for the cross-contamination test. + cc_thresh : float, default: 0.1 + The threshold on the cross-contamination for considering a merge. + censored_period_ms : float, default: 0.3 + Used to compute the refractory period violations aka "contamination". + refractory_period_ms : float, default: 1 + Used to compute the refractory period violations aka "contamination". + sigma_smooth_ms : float, default: 0.6 + Parameters to smooth the correlogram estimation. + adaptative_window_thresh : float, default: 0.5 + Parameter to detect the window size in correlogram estimation. + censor_correlograms_ms : float, default: 0.15 + The period to censor on the auto and cross-correlograms. + firing_contamination_balance : float, default: 1.5 + Parameter to control the balance between firing rate and contamination in computing unit "quality score". + k_nn : int, default 5 + The number of neighbors to consider for every spike in the recording. + knn_kwargs : dict, default None + The dict of extra params to be passed to knn. + extra_outputs : bool, default: False + If True, an additional dictionary (`outs`) with processed data is returned. + steps : None or list of str, default: None + Which steps to run, if no preset is used. + Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", + "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" + Please check steps explanations above! + presence_distance_kwargs : None|dict, default: None + A dictionary of kwargs to be passed to compute_presence_distance(). + + Returns + ------- + potential_merges: + A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). + List of pairs that could be merged. + outs: + Returned only when extra_outputs=True + A dictionary that contains data for debugging and plotting. + + References + ---------- + This function is inspired and built upon similar functions from Lussac [Llobet]_, + done by Aurelien Wyngaard and Victor Llobet. + https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + """ + warnings.warn( + "get_potential_auto_merge() is deprecated. Use compute_merge_unit_groups() instead", + DeprecationWarning, + stacklevel=2, + ) + + presence_distance_kwargs = presence_distance_kwargs or dict() + knn_kwargs = knn_kwargs or dict() + return compute_merge_unit_groups( + sorting_analyzer, + preset, + resolve_graph, + steps_params={ + "num_spikes": {"min_spikes": min_spikes}, + "snr": {"min_snr": min_snr}, + "remove_contaminated": { + "contamination_thresh": contamination_thresh, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "unit_locations": {"max_distance_um": max_distance_um}, + "correlogram": { + "corr_diff_thresh": corr_diff_thresh, + "censor_correlograms_ms": censor_correlograms_ms, + "sigma_smooth_ms": sigma_smooth_ms, + "adaptative_window_thresh": adaptative_window_thresh, + }, + "template_similarity": {"template_diff_thresh": template_diff_thresh}, + "presence_distance": {"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, + "knn": {"k_nn": k_nn, **knn_kwargs}, + "cross_contamination": { + "cc_thresh": cc_thresh, + "p_value": p_value, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "quality_score": { + "firing_contamination_balance": firing_contamination_balance, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + }, + compute_needed_extensions=True, + extra_outputs=extra_outputs, + steps=steps, + ) def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, **knn_kwargs): @@ -661,10 +857,10 @@ def check_improve_contaminations_score( f_new = compute_firing_rates(sorting_analyzer_new)[unit_id1] # old and new scores - k = firing_contamination_balance - score_1 = f_1 * (1 - (k + 1) * c_1) - score_2 = f_2 * (1 - (k + 1) * c_2) - score_new = f_new * (1 - (k + 1) * c_new) + k = 1 + firing_contamination_balance + score_1 = f_1 * (1 - k * c_1) + score_2 = f_2 * (1 - k * c_2) + score_new = f_new * (1 - k * c_new) if score_new < score_1 or score_new < score_2: # the score is not improved diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 33fd06d27a..4c05f41a4c 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -3,16 +3,16 @@ from spikeinterface.core import create_sorting_analyzer from spikeinterface.core.generate import inject_some_split_units -from spikeinterface.curation import get_potential_auto_merge +from spikeinterface.curation import compute_merge_unit_groups, auto_merge from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation @pytest.mark.parametrize( - "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms"] + "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", None] ) -def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): +def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset): print(sorting_analyzer_for_curation) sorting = sorting_analyzer_for_curation.sorting @@ -47,32 +47,38 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): ) if preset is not None: - potential_merges, outs = get_potential_auto_merge( + # do not resolve graph for checking true pairs + merge_unit_groups, outs = compute_merge_unit_groups( sorting_analyzer, preset=preset, - min_spikes=1000, - max_distance_um=150.0, - contamination_thresh=0.2, - corr_diff_thresh=0.16, - template_diff_thresh=0.25, - censored_period_ms=0.0, - refractory_period_ms=4.0, - sigma_smooth_ms=0.6, - adaptative_window_thresh=0.5, - firing_contamination_balance=1.5, + resolve_graph=False, + # min_spikes=1000, + # max_distance_um=150.0, + # contamination_thresh=0.2, + # corr_diff_thresh=0.16, + # template_diff_thresh=0.25, + # censored_period_ms=0.0, + # refractory_period_ms=4.0, + # sigma_smooth_ms=0.6, + # adaptative_window_thresh=0.5, + # firing_contamination_balance=1.5, extra_outputs=True, + **job_kwargs, ) if preset == "x_contaminations": - assert len(potential_merges) == num_unit_splited + assert len(merge_unit_groups) == num_unit_splited for true_pair in other_ids.values(): true_pair = tuple(true_pair) - assert true_pair in potential_merges + assert true_pair in merge_unit_groups else: # when preset is None you have to specify the steps with pytest.raises(ValueError): - potential_merges = get_potential_auto_merge(sorting_analyzer, preset=preset) - potential_merges = get_potential_auto_merge( - sorting_analyzer, preset=preset, steps=["min_spikes", "min_snr", "remove_contaminated", "unit_positions"] + merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset=preset) + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, + preset=preset, + steps=["num_spikes", "snr", "remove_contaminated", "unit_locations"], + **job_kwargs, ) # DEBUG @@ -93,7 +99,7 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): # m = correlograms.shape[2] // 2 - # for unit_id1, unit_id2 in potential_merges[:5]: + # for unit_id1, unit_id2 in merge_unit_groups[:5]: # unit_ind1 = sorting_with_split.id_to_index(unit_id1) # unit_ind2 = sorting_with_split.id_to_index(unit_id2) @@ -129,4 +135,6 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): if __name__ == "__main__": sorting_analyzer = make_sorting_analyzer(sparse=True) - test_get_auto_merge_list(sorting_analyzer) + # preset = "x_contaminations" + preset = None + test_compute_merge_unit_groups(sorting_analyzer, preset=preset) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 3a4be9213a..ab08401382 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -20,10 +20,10 @@ def export_report( **job_kwargs, ): """ - Exports a SI spike sorting report. The report includes summary figures of the spike sorting output - (e.g. amplitude distributions, unit localization and depth VS amplitude) as well as unit-specific reports, - that include waveforms, templates, template maps, ISI distributions, and more. - + Exports a SI spike sorting report. The report includes summary figures of the spike sorting output. + What is plotted depends on what has been calculated. Unit locations and unit waveforms are always included. + Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms`, + and `spike_amplitudes` have been computed for the given `sorting_analyzer`. Parameters ---------- diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index e6278fc59f..59c12a9923 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -3,6 +3,8 @@ import warnings import numpy as np +from spikeinterface.core import SortingAnalyzer, Templates, compute_sparsity +from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array, get_template_extremum_channel try: import numba @@ -12,10 +14,6 @@ HAVE_NUMBA = False -from spikeinterface.core import compute_sparsity, SortingAnalyzer, Templates -from spikeinterface.core.template_tools import get_template_extremum_channel, _get_nbefore, get_dense_templates_array - - def compute_monopolar_triangulation( sorting_analyzer_or_templates: SortingAnalyzer | Templates, unit_ids=None, @@ -77,7 +75,11 @@ def compute_monopolar_triangulation( contact_locations = sorting_analyzer_or_templates.get_channel_locations() - sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) + else: + sparsity = sorting_analyzer_or_templates.sparsity + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -106,7 +108,7 @@ def compute_monopolar_triangulation( # wf is (nsample, nchan) - chann is only nieghboor wf = templates[i, :, :][:, chan_inds] if feature == "ptp": - wf_data = wf.ptp(axis=0) + wf_data = np.ptp(wf, axis=0) elif feature == "energy": wf_data = np.linalg.norm(wf, axis=0) elif feature == "peak_voltage": @@ -157,9 +159,13 @@ def compute_center_of_mass( assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" - sparsity = compute_sparsity( - sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um - ) + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity( + sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um + ) + else: + sparsity = sorting_analyzer_or_templates.sparsity + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -180,7 +186,7 @@ def compute_center_of_mass( wf = templates[i, :, :] if feature == "ptp": - wf_data = (wf[:, chan_inds]).ptp(axis=0) + wf_data = np.ptp(wf[:, chan_inds], axis=0) elif feature == "mean": wf_data = (wf[:, chan_inds]).mean(axis=0) elif feature == "energy": @@ -650,8 +656,55 @@ def get_convolution_weights( enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True) +def compute_location_max_channel( + templates_or_sorting_analyzer: SortingAnalyzer | Templates, + unit_ids=None, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", +) -> np.ndarray: + """ + Localize a unit using max channel. + + This uses internally `get_template_extremum_channel()` + + + Parameters + ---------- + templates_or_sorting_analyzer : SortingAnalyzer | Templates + A SortingAnalyzer or Templates object + unit_ids: list[str] | list[int] | None + A list of unit_id to restrict the computation + peak_sign : "neg" | "pos" | "both" + Sign of the template to find extremum channels + mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum" : take the peak value (max or min depending on `peak_sign`) + * "at_index" : take value at `nbefore` index + * "peak_to_peak" : take the peak-to-peak amplitude + + Returns + ------- + unit_locations: np.ndarray + 2d + """ + extremum_channels_index = get_template_extremum_channel( + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" + ) + contact_locations = templates_or_sorting_analyzer.get_channel_locations() + if unit_ids is None: + unit_ids = templates_or_sorting_analyzer.unit_ids + else: + unit_ids = np.asarray(unit_ids) + unit_locations = np.zeros((unit_ids.size, 2), dtype="float32") + for i, unit_id in enumerate(unit_ids): + unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id]] + + return unit_locations + + _unit_location_methods = { "center_of_mass": compute_center_of_mass, "grid_convolution": compute_grid_convolution, "monopolar_triangulation": compute_monopolar_triangulation, + "max_channel": compute_location_max_channel, } diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 306e9594b8..6e7bcf21b8 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -97,7 +97,7 @@ class ComputeTemplateMetrics(AnalyzerExtension): extension_name = "template_metrics" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index cfa9d89fea..6c30e2730b 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -44,7 +44,7 @@ class ComputeTemplateSimilarity(AnalyzerExtension): extension_name = "template_similarity" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False need_backward_compatibility_on_load = True diff --git a/src/spikeinterface/postprocessing/tests/test_unit_locations.py b/src/spikeinterface/postprocessing/tests/test_unit_locations.py index c40a917a2b..545edb3497 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_locations.py @@ -13,6 +13,7 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite): dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), dict(method="monopolar_triangulation", radius_um=150), dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), + dict(method="max_channel"), ], ) def test_extension(self, params): diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 4029fc88c7..3f6dd47eec 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -39,7 +39,7 @@ class ComputeUnitLocations(AnalyzerExtension): extension_name = "unit_locations" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False need_backward_compatibility_on_load = True diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 334ebb02d2..d5fc9d2025 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -63,18 +63,15 @@ def __init__( f"Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start/end frames." ) self._decimation_offset = decimation_offset - resample_rate = self._orig_samp_freq / self._decimation_factor + decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor - BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate) + BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency) - # in case there was a time_vector, it will be dropped for sanity. - # This is not necessary but consistent with ResampleRecording for parent_segment in recording._recording_segments: - parent_segment.time_vector = None self.add_recording_segment( DecimateRecordingSegment( parent_segment, - resample_rate, + decimated_sampling_frequency, self._orig_samp_freq, decimation_factor, decimation_offset, @@ -93,22 +90,26 @@ class DecimateRecordingSegment(BaseRecordingSegment): def __init__( self, parent_recording_segment, - resample_rate, + decimated_sampling_frequency, parent_rate, decimation_factor, decimation_offset, dtype, ): - if parent_recording_segment.t_start is None: - new_t_start = None + if parent_recording_segment.time_vector is not None: + time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor] + decimated_sampling_frequency = None + t_start = None else: - new_t_start = parent_recording_segment.t_start + decimation_offset / parent_rate + time_vector = None + if parent_recording_segment.t_start is None: + t_start = None + else: + t_start = parent_recording_segment.t_start + (decimation_offset / parent_rate) # Do not use BasePreprocessorSegment bcause we have to reset the sampling rate! BaseRecordingSegment.__init__( - self, - sampling_frequency=resample_rate, - t_start=new_t_start, + self, sampling_frequency=decimated_sampling_frequency, t_start=t_start, time_vector=time_vector ) self._parent_segment = parent_recording_segment self._decimation_factor = decimation_factor diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 8f38f01469..85169011d8 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -71,8 +71,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see if mode in ["noise"]: if noise_levels is None: + random_slices_kwargs = random_chunk_kwargs.copy() + random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels( - recording, return_scaled=False, concatenated=True, seed=seed, **random_chunk_kwargs + recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs ) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index 100972f762..aab17560a6 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -8,19 +8,14 @@ import numpy as np -@pytest.mark.parametrize("N_segments", [1, 2]) -@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101]) -@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101]) -@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) -@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) -def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame): - rec = generate_recording() - - segment_num_samps = [101 + i for i in range(N_segments)] - +@pytest.mark.parametrize("num_segments", [1, 2]) +@pytest.mark.parametrize("decimation_offset", [0, 1, 5, 21, 101]) +@pytest.mark.parametrize("decimation_factor", [1, 7, 50]) +def test_decimate(num_segments, decimation_offset, decimation_factor): + segment_num_samps = [20000, 40000] rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) - parent_traces = [rec.get_traces(i) for i in range(N_segments)] + parent_traces = [rec.get_traces(i) for i in range(num_segments)] if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor: with pytest.raises(ValueError): @@ -28,19 +23,59 @@ def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, return decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) - decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)] + decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] - if start_frame is None: - start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) - if end_frame is None: - end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) + for start_frame in [0, 1, 5, None, 1000]: + for end_frame in [0, 1, 5, None, 1000]: + if start_frame is None: + start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + if end_frame is None: + end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) - for i in range(N_segments): + for i in range(num_segments): + assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] + assert np.all( + decimated_rec.get_traces(i, start_frame, end_frame) + == decimated_parent_traces[i][start_frame:end_frame] + ) + + for i in range(num_segments): assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] assert np.all( decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame] ) +def test_decimate_with_times(): + rec = generate_recording(durations=[5, 10]) + + # test with times + times = [rec.get_times(0) + 10, rec.get_times(1) + 20] + for i, t in enumerate(times): + rec.set_times(t, i) + + decimation_factor = 2 + decimation_offset = 1 + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + # test with t_start + rec = generate_recording(durations=[5, 10]) + t_starts = [10, 20] + for t_start, rec_segment in zip(t_starts, rec._recording_segments): + rec_segment.t_start = t_start + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + if __name__ == "__main__": test_decimate() diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 321d7c9df2..e32d96901e 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -55,11 +55,11 @@ def test_scaling_in_preprocessing_chain(): recording.set_channel_gains(gains) recording.set_channel_offsets(offsets) - centered_recording = CenterRecording(scale_to_uV(recording=recording)) + centered_recording = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True) # Chain preprocessors - centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording)) + centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_preprocessor = centered_recording_scaled.get_traces() np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor) @@ -68,3 +68,8 @@ def test_scaling_in_preprocessing_chain(): traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True) np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument) + + +if __name__ == "__main__": + test_scale_to_uV() + test_scaling_in_preprocessing_chain() diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence.py index 6c2e8ec8b5..20d4f6dfc7 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence.py @@ -9,6 +9,8 @@ import numpy as np +from pathlib import Path + def test_silence(create_cache_folder): @@ -46,4 +48,5 @@ def test_silence(create_cache_folder): if __name__ == "__main__": - test_silence() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_silence(cache_folder) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 04b731de4f..b40627d836 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -5,13 +5,15 @@ from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix +from pathlib import Path + def test_whiten(create_cache_folder): cache_folder = create_cache_folder rec = generate_recording(num_channels=4, seed=2205) print(rec.get_channel_locations()) - random_chunk_kwargs = {} + random_chunk_kwargs = {"seed": 2205} W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) # print(W) # print(M) @@ -47,4 +49,5 @@ def test_whiten(create_cache_folder): if __name__ == "__main__": - test_whiten() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_whiten(cache_folder) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 195969ff79..57400c1199 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -19,6 +19,8 @@ class WhitenRecording(BasePreprocessor): recording : RecordingExtractor The recording extractor to be whitened. dtype : None or dtype, default: None + Datatype of the output recording (covariance matrix estimation + and whitening are performed in float32). If None the the parent dtype is kept. For integer dtype a int_scale must be also given. mode : "global" | "local", default: "global" @@ -74,7 +76,9 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, "For recording with dtype=int you must set dtype=float32 OR set a int_scale" + assert ( + int_scale is not None + ), "For recording with dtype=int you must set the output dtype to float OR set a int_scale" if W is not None: W = np.asarray(W) @@ -124,7 +128,7 @@ def __init__(self, parent_recording_segment, W, M, dtype, int_scale): def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) traces_dtype = traces.dtype - # if uint --> force int + # if uint --> force float if traces_dtype.kind == "u": traces = traces.astype("float32") @@ -185,6 +189,7 @@ def compute_whitening_matrix( """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) + random_data = random_data.astype(np.float32) regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 3502d27548..c59fa29c05 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -145,9 +145,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo elif recording.check_serializability("pickle"): recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder) else: - # TODO: deprecate and finally remove this after 0.100 - d = {"warning": "The recording is not serializable to json"} - rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") + raise RuntimeError( + "This recording is not serializable and so can not be sorted. Consider `recording.save()` to save a " + "compatible binary file." + ) return output_folder diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py index e2b6b1a2bc..bfedd4e1ee 100644 --- a/src/spikeinterface/sortingcomponents/motion/dredge.py +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -22,20 +22,19 @@ """ +import gc import warnings -from tqdm.auto import trange import numpy as np - -import gc +from tqdm.auto import trange from .motion_utils import ( Motion, + get_spatial_bin_edges, get_spatial_windows, get_window_domains, - scipy_conv1d, make_2d_motion_histogram, - get_spatial_bin_edges, + scipy_conv1d, ) @@ -979,7 +978,7 @@ def xcorr_windows( if max_disp_um is None: if rigid: - max_disp_um = int(spatial_bin_edges_um.ptp() // 4) + max_disp_um = int(np.ptp(spatial_bin_edges_um) // 4) else: max_disp_um = int(win_scale_um // 4) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 5b1d33b334..d03744f8f9 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -57,6 +57,7 @@ def detect_peaks( folder=None, names=None, skip_after_n_peaks=None, + recording_slices=None, **kwargs, ): """Peak detection based on threshold crossing in term of k x MAD. @@ -83,6 +84,10 @@ def detect_peaks( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. + recording_slices : None | list[tuple] + Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + It must be a list of (segment_index, frame_start, frame_stop). + If None (default), the function iterates over the entire duration of the recording. {method_doc} {job_doc} @@ -135,6 +140,7 @@ def detect_peaks( folder=folder, names=names, skip_after_n_peaks=skip_after_n_peaks, + recording_slices=recording_slices, ) return outs diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index c8acd93dc2..c211a277f8 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -9,7 +9,13 @@ class AutoCorrelogramsWidget(CrossCorrelogramsWidget): # the doc is copied form CrossCorrelogramsWidget def __init__(self, *args, **kargs): - CrossCorrelogramsWidget.__init__(self, *args, **kargs) + _ = kargs.pop("min_similarity_for_correlograms", 0.0) + CrossCorrelogramsWidget.__init__( + self, + *args, + **kargs, + min_similarity_for_correlograms=None, + ) def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index cdb2041aa3..88dd803323 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -21,7 +21,8 @@ class CrossCorrelogramsWidget(BaseWidget): List of unit ids min_similarity_for_correlograms : float, default: 0.2 For sortingview backend. Threshold for computing pair-wise cross-correlograms. - If template similarity between two units is below this threshold, the cross-correlogram is not displayed + If template similarity between two units is below this threshold, the cross-correlogram is not displayed. + For auto-correlograms plot, this is automatically set to None. window_ms : float, default: 100.0 Window for CCGs in ms. If correlograms are already computed (e.g. with SortingAnalyzer), this argument is ignored diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 86f2350a85..f5dadc780f 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -52,6 +52,8 @@ class TracesWidget(BaseWidget): If dict, keys should be the same as recording keys scale : float, default: 1 Scale factor for the traces + vspacing_factor : float, default: 1.5 + Vertical spacing between channels as a multiple of maximum channel amplitude with_colorbar : bool, default: True When mode is "map", a colorbar is added tile_size : int, default: 1500 @@ -82,6 +84,7 @@ def __init__( tile_size=1500, seconds_per_row=0.2, scale=1, + vspacing_factor=1.5, with_colorbar=True, add_legend=True, backend=None, @@ -168,7 +171,7 @@ def __init__( traces0 = list_traces[0] mean_channel_std = np.mean(np.std(traces0, axis=0)) max_channel_amp = np.max(np.max(np.abs(traces0), axis=0)) - vspacing = max_channel_amp * 1.5 + vspacing = max_channel_amp * vspacing_factor if rec0.get_channel_groups() is None: color_groups = False diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 755e60ccbf..9466110110 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -107,82 +107,90 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # and use custum grid spec fig = self.figure nrows = 2 - ncols = 3 - if sorting_analyzer.has_extension("correlograms") or sorting_analyzer.has_extension("spike_amplitudes"): + ncols = 2 + if sorting_analyzer.has_extension("correlograms"): + ncols += 1 + if sorting_analyzer.has_extension("waveforms"): ncols += 1 if sorting_analyzer.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) + col_counter = 0 - if sorting_analyzer.has_extension("unit_locations"): - ax1 = fig.add_subplot(gs[:2, 0]) - # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) - w = UnitLocationsWidget( - sorting_analyzer, - unit_ids=[unit_id], - unit_colors=unit_colors, - plot_legend=False, - backend="matplotlib", - ax=ax1, - **unitlocationswidget_kwargs, - ) - - unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - unit_location = unit_locations[unit_id] - x, y = unit_location[0], unit_location[1] - ax1.set_xlim(x - 80, x + 80) - ax1.set_ylim(y - 250, y + 250) - ax1.set_xticks([]) - ax1.set_xlabel(None) - ax1.set_ylabel(None) - - ax2 = fig.add_subplot(gs[:2, 1]) - w = UnitWaveformsWidget( + # Unit locations and unit waveform plots are always generated + ax_unit_locations = fig.add_subplot(gs[:2, col_counter]) + _ = UnitLocationsWidget( + sorting_analyzer, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_legend=False, + backend="matplotlib", + ax=ax_unit_locations, + **unitlocationswidget_kwargs, + ) + col_counter += 1 + + unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") + unit_location = unit_locations[unit_id] + x, y = unit_location[0], unit_location[1] + ax_unit_locations.set_xlim(x - 80, x + 80) + ax_unit_locations.set_ylim(y - 250, y + 250) + ax_unit_locations.set_xticks([]) + ax_unit_locations.set_xlabel(None) + ax_unit_locations.set_ylabel(None) + + ax_unit_waveforms = fig.add_subplot(gs[:2, col_counter]) + _ = UnitWaveformsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, plot_templates=True, + plot_waveforms=sorting_analyzer.has_extension("waveforms"), same_axis=True, plot_legend=False, sparsity=sparsity, backend="matplotlib", - ax=ax2, + ax=ax_unit_waveforms, **unitwaveformswidget_kwargs, ) + col_counter += 1 - ax2.set_title(None) + ax_unit_waveforms.set_title(None) - ax3 = fig.add_subplot(gs[:2, 2]) - UnitWaveformDensityMapWidget( - sorting_analyzer, - unit_ids=[unit_id], - unit_colors=unit_colors, - use_max_channel=True, - same_axis=False, - backend="matplotlib", - ax=ax3, - **unitwaveformdensitymapwidget_kwargs, - ) - ax3.set_ylabel(None) + if sorting_analyzer.has_extension("waveforms"): + ax_waveform_density = fig.add_subplot(gs[:2, col_counter]) + UnitWaveformDensityMapWidget( + sorting_analyzer, + unit_ids=[unit_id], + unit_colors=unit_colors, + use_max_channel=True, + same_axis=False, + backend="matplotlib", + ax=ax_waveform_density, + **unitwaveformdensitymapwidget_kwargs, + ) + col_counter += 1 + ax_waveform_density.set_ylabel(None) if sorting_analyzer.has_extension("correlograms"): - ax4 = fig.add_subplot(gs[:2, 3]) + ax_correlograms = fig.add_subplot(gs[:2, col_counter]) AutoCorrelogramsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, backend="matplotlib", - ax=ax4, + ax=ax_correlograms, **autocorrelogramswidget_kwargs, ) + col_counter += 1 - ax4.set_title(None) - ax4.set_yticks([]) + ax_correlograms.set_title(None) + ax_correlograms.set_yticks([]) if sorting_analyzer.has_extension("spike_amplitudes"): - ax5 = fig.add_subplot(gs[2, :3]) - ax6 = fig.add_subplot(gs[2, 3]) - axes = np.array([ax5, ax6]) + ax_spike_amps = fig.add_subplot(gs[2, : col_counter - 1]) + ax_amps_distribution = fig.add_subplot(gs[2, col_counter - 1]) + axes = np.array([ax_spike_amps, ax_amps_distribution]) AmplitudesWidget( sorting_analyzer, unit_ids=[unit_id],