diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 6d9d2a827f..7ca527e255 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -1,4 +1,5 @@ from __future__ import annotations + import warnings from pathlib import Path @@ -7,14 +8,9 @@ from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets -from .core_tools import ( - convert_bytes_to_str, - convert_seconds_to_str, -) -from .recording_tools import write_binary_recording - - +from .core_tools import convert_bytes_to_str, convert_seconds_to_str from .job_tools import split_job_kwargs +from .recording_tools import write_binary_recording class BaseRecording(BaseRecordingSnippets): @@ -950,11 +946,11 @@ def time_to_sample_index(self, time_s): sample_index = time_s * self.sampling_frequency else: sample_index = (time_s - self.t_start) * self.sampling_frequency - sample_index = round(sample_index) + sample_index = np.round(sample_index).astype(int) else: sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1 - return int(sample_index) + return sample_index def get_num_samples(self) -> int: """Returns the number of samples in this signal segment diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 6e7bcf21b8..1969480503 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -63,23 +63,10 @@ class ComputeTemplateMetrics(AnalyzerExtension): include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False - If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged. 
- metrics_kwargs : dict - Additional arguments to pass to the metric functions. Including: - * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 - * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 - * peak_width_ms: the width in samples to detect peaks, default: 0.2 - * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) - * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 - * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 - * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" - * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 - * spread_threshold: the threshold to compute the spread, default: 0.2 - * spread_smooth_um: the smoothing in um to compute the spread, default: 20 - * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None - - If None, all channels all channels are considered - - If 0 or 1, only the "column" that includes the max channel is considered - - If > 1, only channels within range (+/-) um from the max channel horizontal position are used + If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. + metric_params : dict of dicts or None, default: None + Dictionary with parameters for template metrics calculation. 
+ Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()` Returns ------- @@ -100,15 +87,29 @@ class ComputeTemplateMetrics(AnalyzerExtension): need_recording = False use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True min_channels_for_multi_channel_warning = 10 + def _handle_backward_compatibility_on_load(self): + + # For backwards compatibility - this reformats metrics_kwargs as metric_params + if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: + + metric_params = {} + for metric_name in self.params["metric_names"]: + metric_params[metric_name] = deepcopy(metrics_kwargs) + self.params["metric_params"] = metric_params + + del self.params["metrics_kwargs"] + def _set_params( self, metric_names=None, peak_sign="neg", upsampling_factor=10, sparsity=None, + metric_params=None, metrics_kwargs=None, include_multi_channel_metrics=False, delete_existing_metrics=False, @@ -134,33 +135,24 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - if metrics_kwargs is None: - metrics_kwargs_ = _default_function_kwargs.copy() - if len(other_kwargs) > 0: - for m in other_kwargs: - if m in metrics_kwargs_: - metrics_kwargs_[m] = other_kwargs[m] - else: - metrics_kwargs_ = _default_function_kwargs.copy() - metrics_kwargs_.update(metrics_kwargs) + if metrics_kwargs is not None and metric_params is None: + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. 
Please use `metric_params` instead" + + metric_params = {} + for metric_name in metric_names: + metric_params[metric_name] = deepcopy(metrics_kwargs) + + metric_params_ = get_default_tm_params(metric_names) + for k in metric_params_: + if metric_params is not None and k in metric_params: + metric_params_[k].update(metric_params[k]) metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_params = tm_extension.params["metrics_kwargs"] - # checks that existing metrics were calculated using the same params - if existing_params != metrics_kwargs_: - warnings.warn( - f"The parameters used to calculate the previous template metrics are different" - f"than those used now.\nPrevious parameters: {existing_params}\nCurrent " - f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..." - ) - tm_extension.params["metric_names"] = [] - existing_metric_names = [] - else: - existing_metric_names = tm_extension.params["metric_names"] - + existing_metric_names = tm_extension.params["metric_names"] existing_metric_names_propogated = [ metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute ] @@ -171,7 +163,7 @@ def _set_params( sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - metrics_kwargs=metrics_kwargs_, + metric_params=metric_params_, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, ) @@ -273,7 +265,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self.params["metrics_kwargs"], + **self.params["metric_params"][metric_name], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -312,7 +304,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, 
verbose=False, metri template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self.params["metrics_kwargs"], + **self.params["metric_params"][metric_name], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -326,8 +318,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri def _run(self, verbose=False): - delete_existing_metrics = self.params["delete_existing_metrics"] metrics_to_compute = self.params["metrics_to_compute"] + delete_existing_metrics = self.params["delete_existing_metrics"] # compute the metrics which have been specified by the user computed_metrics = self._compute_metrics( @@ -343,9 +335,21 @@ def _run(self, verbose=False): ): existing_metrics = tm_extension.params["metric_names"] + existing_metrics = [] + # here we get in the loaded via the dict only (to avoid full loading from disk after params reset) + tm_extension = self.sorting_analyzer.extensions.get("template_metrics", None) + if ( + delete_existing_metrics is False + and tm_extension is not None + and tm_extension.data.get("metrics") is not None + ): + existing_metrics = tm_extension.params["metric_names"] + # append the metrics which were previously computed for metric_name in set(existing_metrics).difference(metrics_to_compute): - computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] + # some metrics names produce data columns with other names. This deals with that. 
+ for column_name in tm_compute_name_to_column_names[metric_name]: + computed_metrics[column_name] = tm_extension.data["metrics"][column_name] self.data["metrics"] = computed_metrics @@ -372,6 +376,35 @@ def _get_data(self): ) +def get_default_tm_params(metric_names): + if metric_names is None: + metric_names = get_template_metric_names() + + base_tm_params = _default_function_kwargs + + metric_params = {} + for metric_name in metric_names: + metric_params[metric_name] = deepcopy(base_tm_params) + + return metric_params + + +# a dict converting the name of the metric for computation to the output of that computation +tm_compute_name_to_column_names = { + "peak_to_valley": ["peak_to_valley"], + "peak_trough_ratio": ["peak_trough_ratio"], + "half_width": ["half_width"], + "repolarization_slope": ["repolarization_slope"], + "recovery_slope": ["recovery_slope"], + "num_positive_peaks": ["num_positive_peaks"], + "num_negative_peaks": ["num_negative_peaks"], + "velocity_above": ["velocity_above"], + "velocity_below": ["velocity_below"], + "exp_decay": ["exp_decay"], + "spread": ["spread"], +} + + def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 5056d4ff2a..1bf49f64c1 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,5 +1,5 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -from spikeinterface.postprocessing import ComputeTemplateMetrics +from spikeinterface.postprocessing import ComputeTemplateMetrics, compute_template_metrics import pytest import csv @@ -8,6 +8,49 @@ template_metrics = list(_single_channel_metric_name_to_func.keys()) +def test_different_params_template_metrics(small_sorting_analyzer): + 
""" + Computes template metrics using different params, and check that they are + actually calculated using the different params. + """ + compute_template_metrics( + sorting_analyzer=small_sorting_analyzer, + metric_names=["exp_decay", "spread", "half_width"], + metric_params={"exp_decay": {"recovery_window_ms": 0.8}, "spread": {"spread_smooth_um": 15}}, + ) + + tm_extension = small_sorting_analyzer.get_extension("template_metrics") + tm_params = tm_extension.params["metric_params"] + + assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 + assert tm_params["spread"]["recovery_window_ms"] == 0.7 + assert tm_params["half_width"]["recovery_window_ms"] == 0.7 + + assert tm_params["spread"]["spread_smooth_um"] == 15 + assert tm_params["exp_decay"]["spread_smooth_um"] == 20 + assert tm_params["half_width"]["spread_smooth_um"] == 20 + + +def test_backwards_compat_params_template_metrics(small_sorting_analyzer): + """ + Computes template metrics using the metrics_kwargs keyword + """ + compute_template_metrics( + sorting_analyzer=small_sorting_analyzer, + metric_names=["exp_decay", "spread"], + metrics_kwargs={"recovery_window_ms": 0.8}, + ) + + tm_extension = small_sorting_analyzer.get_extension("template_metrics") + tm_params = tm_extension.params["metric_params"] + + assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 + assert tm_params["spread"]["recovery_window_ms"] == 0.8 + + assert tm_params["spread"]["spread_smooth_um"] == 20 + assert tm_params["exp_decay"]["spread_smooth_um"] == 20 + + def test_compute_new_template_metrics(small_sorting_analyzer): """ Computes template metrics then computes a subset of template metrics, and checks @@ -17,6 +60,8 @@ def test_compute_new_template_metrics(small_sorting_analyzer): are deleted. 
""" + small_sorting_analyzer.delete_extension("template_metrics") + # calculate just exp_decay small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") @@ -47,7 +92,7 @@ def test_compute_new_template_metrics(small_sorting_analyzer): # check that, when parameters are changed, the old metrics are deleted small_sorting_analyzer.compute( - {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + {"template_metrics": {"metric_names": ["exp_decay"], "metric_params": {"recovery_window_ms": 0.6}}} ) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 69624e8346..c789d1af82 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -6,6 +6,7 @@ from copy import deepcopy import platform from tqdm.auto import tqdm +from warnings import warn import numpy as np @@ -55,6 +56,7 @@ def get_quality_pca_metric_list(): def compute_pc_metrics( sorting_analyzer, metric_names=None, + metric_params=None, qm_params=None, unit_ids=None, seed=None, @@ -73,7 +75,7 @@ def compute_pc_metrics( metric_names : list of str, default: None The list of PC metrics to compute. If not provided, defaults to all PC metrics. - qm_params : dict or None + metric_params : dict or None Dictionary with parameters for each PC metric function. unit_ids : list of int or None List of unit ids to compute metrics for. @@ -89,6 +91,14 @@ def compute_pc_metrics( pc_metrics : dict The computed PC metrics. """ + + if qm_params is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0. 
Please use metric_params instead" + ) + warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + metric_params = qm_params + pca_ext = sorting_analyzer.get_extension("principal_components") assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" @@ -96,8 +106,8 @@ def compute_pc_metrics( if metric_names is None: metric_names = _possible_pc_metric_names.copy() - if qm_params is None: - qm_params = _default_params + if metric_params is None: + metric_params = _default_params extremum_channels = get_template_extremum_channel(sorting_analyzer) @@ -150,7 +160,7 @@ def compute_pc_metrics( pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) - func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process) + func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, metric_params, max_threads_per_process) items.append(func_args) if not run_in_parallel and non_nn_metrics: @@ -187,7 +197,7 @@ def compute_pc_metrics( units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids)) func = _nn_metric_name_to_func[metric_name] - metric_params = qm_params[metric_name] if metric_name in qm_params else {} + metric_params = metric_params[metric_name] if metric_name in metric_params else {} for _, unit_id in units_loop: try: @@ -216,7 +226,7 @@ def compute_pc_metrics( def calculate_pc_metrics( - sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False + sorting_analyzer, metric_names=None, metric_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False ): warnings.warn( "The `calculate_pc_metrics` function is deprecated and will be removed in 0.103.0. 
Please use compute_pc_metrics instead", @@ -227,7 +237,7 @@ def calculate_pc_metrics( pc_metrics = compute_pc_metrics( sorting_analyzer, metric_names=metric_names, - qm_params=qm_params, + metric_params=metric_params, unit_ids=unit_ids, seed=seed, n_jobs=n_jobs, @@ -980,16 +990,16 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): def pca_metrics_one_unit(args): - (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args + (pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params, max_threads_per_process) = args if max_threads_per_process is None: - return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) else: with threadpool_limits(limits=int(max_threads_per_process)): - return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) -def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params): +def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params): pc_metrics = {} # metrics if "isolation_distance" in metric_names or "l_ratio" in metric_names: @@ -1018,7 +1028,7 @@ def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_ if "nearest_neighbor" in metric_names: try: nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics( - pcs_flat, labels, unit_id, **qm_params["nearest_neighbor"] + pcs_flat, labels, unit_id, **metric_params["nearest_neighbor"] ) except: nn_hit_rate = np.nan @@ -1027,7 +1037,7 @@ def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_ pc_metrics["nn_miss_rate"] = nn_miss_rate if "silhouette" in metric_names: - silhouette_method = qm_params["silhouette"]["method"] + silhouette_method = 
metric_params["silhouette"]["method"] if "simplified" in silhouette_method: try: unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index b6a50d60f5..d71450853f 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -6,6 +6,7 @@ from copy import deepcopy import numpy as np +from warnings import warn from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -15,7 +16,7 @@ compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names, - compute_name_to_column_names, + qm_compute_name_to_column_names, ) from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params @@ -31,7 +32,7 @@ class ComputeQualityMetrics(AnalyzerExtension): A SortingAnalyzer object. metric_names : list or None List of quality metrics to compute. - qm_params : dict or None + metric_params : dict of dicts or None Dictionary with parameters for quality metrics calculation. 
Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` skip_pc_metrics : bool, default: False @@ -54,10 +55,18 @@ class ComputeQualityMetrics(AnalyzerExtension): need_recording = False use_nodepipeline = False need_job_kwargs = True + need_backward_compatibility_on_load = True + + def _handle_backward_compatibility_on_load(self): + # For backwards compatibility - this renames qm_params as metric_params + if (qm_params := self.params.get("qm_params")) is not None: + self.params["metric_params"] = qm_params + del self.params["qm_params"] def _set_params( self, metric_names=None, + metric_params=None, qm_params=None, peak_sign=None, seed=None, @@ -65,6 +74,12 @@ def _set_params( delete_existing_metrics=False, metrics_to_compute=None, ): + if qm_params is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = qm_params + warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) @@ -80,12 +95,12 @@ def _set_params( if "drift" in metric_names: metric_names.remove("drift") - qm_params_ = get_default_qm_params() - for k in qm_params_: - if qm_params is not None and k in qm_params: - qm_params_[k].update(qm_params[k]) - if "peak_sign" in qm_params_[k] and peak_sign is not None: - qm_params_[k]["peak_sign"] = peak_sign + metric_params_ = get_default_qm_params() + for k in metric_params_: + if metric_params is not None and k in metric_params: + metric_params_[k].update(metric_params[k]) + if "peak_sign" in metric_params_[k] and peak_sign is not None: + metric_params_[k]["peak_sign"] = peak_sign metrics_to_compute = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") @@ -101,7 +116,7 @@ def _set_params( metric_names=metric_names, peak_sign=peak_sign, seed=seed, - qm_params=qm_params_, + 
metric_params=metric_params_, skip_pc_metrics=skip_pc_metrics, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, @@ -141,7 +156,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri """ import pandas as pd - qm_params = self.params["qm_params"] + metric_params = self.params["metric_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] @@ -177,7 +192,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri func = _misc_metric_name_to_func[metric_name] - params = qm_params[metric_name] if metric_name in qm_params else {} + params = metric_params[metric_name] if metric_name in metric_params else {} res = func(sorting_analyzer, unit_ids=non_empty_unit_ids, **params) # QM with uninstall dependencies might return None if res is not None: @@ -205,7 +220,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri # sparsity=sparsity, progress_bar=progress_bar, n_jobs=n_jobs, - qm_params=qm_params, + metric_params=metric_params, seed=seed, ) for col, values in pc_metrics.items(): @@ -246,7 +261,7 @@ def _run(self, verbose=False, **job_kwargs): # append the metrics which were previously computed for metric_name in set(existing_metrics).difference(metrics_to_compute): # some metrics names produce data columns with other names. This deals with that. 
- for column_name in compute_name_to_column_names[metric_name]: + for column_name in qm_compute_name_to_column_names[metric_name]: computed_metrics[column_name] = qm_extension.data["metrics"][column_name] self.data["metrics"] = computed_metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 375dd320ae..fc7e92b50d 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -55,7 +55,7 @@ } # a dict converting the name of the metric for computation to the output of that computation -compute_name_to_column_names = { +qm_compute_name_to_column_names = { "num_spikes": ["num_spikes"], "firing_rate": ["firing_rate"], "presence_ratio": ["presence_ratio"], diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 4c0890b62b..20869aa44a 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -69,7 +69,7 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): assert calculated_metrics == ["snr"] small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}} + {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}} ) small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) @@ -96,13 +96,13 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): # check that, when parameters are changed, the data and metadata are updated old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": ["snr"], "qm_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": 
{"peak_mode": "peak_to_peak"}}}} ) new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") new_snr_data = new_quality_metric_extension.get_data()["snr"].values assert np.all(old_snr_data != new_snr_data) - assert new_quality_metric_extension.params["qm_params"]["snr"]["peak_mode"] == "peak_to_peak" + assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" # check that all quality metrics are deleted when parents are recomputed, even after # recomputation @@ -280,10 +280,10 @@ def test_unit_id_order_independence(small_sorting_analyzer): } quality_metrics_1 = compute_quality_metrics( - small_sorting_analyzer, metric_names=get_quality_metric_list(), qm_params=qm_params + small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params ) quality_metrics_2 = compute_quality_metrics( - small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params + small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params ) for metric, metric_2_data in quality_metrics_2.items(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index a6415c58e8..60f0490f51 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -24,14 +24,14 @@ def test_compute_quality_metrics(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=["snr"], - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, seed=2205, ) # print(metrics) qm = sorting_analyzer.get_extension("quality_metrics") - assert qm.params["qm_params"]["isi_violation"]["isi_threshold_ms"] == 2 + assert 
qm.params["metric_params"]["isi_violation"]["isi_threshold_ms"] == 2 assert "snr" in metrics.columns assert "isolation_distance" not in metrics.columns @@ -40,7 +40,7 @@ def test_compute_quality_metrics(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -54,7 +54,7 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -68,7 +68,7 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): metrics_norec = compute_quality_metrics( sorting_analyzer_norec, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -101,7 +101,7 @@ def test_empty_units(sorting_analyzer_simple): metrics_empty = compute_quality_metrics( sorting_analyzer_empty, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, seed=2205, ) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index a5e6ded519..fc8ccb788b 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -6,6 +6,8 @@ from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.preprocessing.filter import fix_dtype +from .motion_utils import ensure_time_bin_edges, ensure_time_bins + def 
correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray: """ @@ -54,6 +56,7 @@ def interpolate_motion_on_traces( segment_index=None, channel_inds=None, interpolation_time_bin_centers_s=None, + interpolation_time_bin_edges_s=None, spatial_interpolation_method="kriging", spatial_interpolation_kwargs={}, dtype=None, @@ -61,7 +64,11 @@ def interpolate_motion_on_traces( """ Apply inverse motion with spatial interpolation on traces. - Traces can be full traces, but also waveforms snippets. + Traces can be full traces, but also waveforms snippets. Times used for looking up + displacements are controlled by interpolation_time_bin_edges_s or + interpolation_time_bin_centers_s, or fall back to the Motion object's time bins + by default; times in the recording outside these time bins use the closest edge + bin's displacement value during interpolation. Parameters ---------- @@ -80,6 +87,9 @@ def interpolate_motion_on_traces( interpolation_time_bin_centers_s : None or np.array Manually specify the time bins which the interpolation happens in for this segment. If None, these are the motion estimate's time bins. + interpolation_time_bin_edges_s : None or np.array + If present, interpolation chunks will be the time bins defined by these edges + rather than interpolation_time_bin_centers_s or the motion's bins. spatial_interpolation_method : "idw" | "kriging", default: "kriging" The spatial interpolation method used to interpolate the channel locations: * idw : Inverse Distance Weighing @@ -119,26 +129,33 @@ def interpolate_motion_on_traces( total_num_chans = channel_locations.shape[0] # -- determine the blocks of frames that will land in the same interpolation time bin - time_bins = interpolation_time_bin_centers_s - if time_bins is None: - time_bins = motion.temporal_bins_s[segment_index] - bin_s = time_bins[1] - time_bins[0] - bins_start = time_bins[0] - 0.5 * bin_s - # nearest bin center for each frame? 
- bin_inds = (times - bins_start) // bin_s - bin_inds = bin_inds.astype(int) + if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: + interpolation_time_bin_centers_s = motion.temporal_bins_s[segment_index] + interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index] + else: + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s + ) + + # bin the frame times according to the interpolation time bins. + # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] + # hence the -1. doing it with "left" is not as nice -- we want t==b[0] + # to lead to i=1 (rounding down). + interpolation_bin_inds = np.searchsorted(interpolation_time_bin_edges_s, times, side="right") - 1 + # the time bins may not cover the whole set of times in the recording, # so we need to clip these indices to the valid range - np.clip(bin_inds, 0, time_bins.size, out=bin_inds) + n_bins = interpolation_time_bin_edges_s.shape[0] - 1 + np.clip(interpolation_bin_inds, 0, n_bins - 1, out=interpolation_bin_inds) # -- what are the possibilities here anyway? 
- bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) + interpolation_bins_here = np.arange(interpolation_bin_inds[0], interpolation_bin_inds[-1] + 1) # inperpolation kernel will be the same per temporal bin interp_times = np.empty(total_num_chans) current_start_index = 0 - for bin_ind in bins_here: - bin_time = time_bins[bin_ind] + for interp_bin_ind in interpolation_bins_here: + bin_time = interpolation_time_bin_centers_s[interp_bin_ind] interp_times.fill(bin_time) channel_motions = motion.get_displacement_at_time_and_depth( interp_times, @@ -166,16 +183,17 @@ def interpolate_motion_on_traces( # ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}") # plt.show() + # quick search logic to find frames corresponding to this interpolation bin in the recording # quickly find the end of this bin, which is also the start of the next next_start_index = current_start_index + np.searchsorted( - bin_inds[current_start_index:], bin_ind + 1, side="left" + interpolation_bin_inds[current_start_index:], interp_bin_ind + 1, side="left" ) - in_bin = slice(current_start_index, next_start_index) + frames_in_bin = slice(current_start_index, next_start_index) # here we use a simple np.matmul even if dirft_kernel can be super sparse. 
# because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) - np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin]) + np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin]) current_start_index = next_start_index return traces_corrected @@ -297,6 +315,7 @@ def __init__( p=1, num_closest=3, interpolation_time_bin_centers_s=None, + interpolation_time_bin_edges_s=None, interpolation_time_bin_size_s=None, dtype=None, **spatial_interpolation_kwargs, @@ -363,9 +382,14 @@ def __init__( # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below - if interpolation_time_bin_centers_s is None: + if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: if interpolation_time_bin_size_s is None: interpolation_time_bin_centers_s = motion.temporal_bins_s + interpolation_time_bin_edges_s = motion.temporal_bin_edges_s + else: + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s + ) for segment_index, parent_segment in enumerate(recording._recording_segments): # finish the per-segment part of the time bin logic @@ -375,8 +399,13 @@ def __init__( t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) halfbin = interpolation_time_bin_size_s / 2.0 segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) + segment_interpolation_time_bin_edges_s = np.arange( + t_start, t_end + halfbin, interpolation_time_bin_size_s + ) + assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,) else: segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index] + segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index] 
rec_segment = InterpolateMotionRecordingSegment( parent_segment, @@ -387,6 +416,7 @@ def __init__( channel_inds, segment_index, segment_interpolation_time_bins_s, + segment_interpolation_time_bin_edges_s, dtype=dtype_, ) self.add_recording_segment(rec_segment) @@ -420,6 +450,7 @@ def __init__( channel_inds, segment_index, interpolation_time_bin_centers_s, + interpolation_time_bin_edges_s, dtype="float32", ): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -429,13 +460,11 @@ def __init__( self.channel_inds = channel_inds self.segment_index = segment_index self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s + self.interpolation_time_bin_edges_s = interpolation_time_bin_edges_s self.dtype = dtype self.motion = motion def get_traces(self, start_frame, end_frame, channel_indices): - if self.time_vector is not None: - raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.") - if start_frame is None: start_frame = 0 if end_frame is None: @@ -453,7 +482,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): channel_inds=self.channel_inds, spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, - interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, + interpolation_time_bin_edges_s=self.interpolation_time_bin_edges_s, ) if channel_indices is not None: diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 635624cca8..680d75f221 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -1,5 +1,5 @@ -import warnings import json +import warnings from pathlib import Path import numpy as np @@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" self.direction = direction 
def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
    """Return both time bin centers and time bin edges, deriving whichever is missing.

    If only edges are given, the centers are the midpoints of consecutive edges.
    If only centers are given, the interior edges are the midpoints of
    consecutive centers and the two outer edges are set to the first and last
    centers. If both are given, they are returned unchanged.

    Parameters
    ----------
    time_bin_centers_s : None or np.array
        1d array of bin centers. Needs at least one element when the edges
        have to be derived from it.
    time_bin_edges_s : None or np.array
        1d array of bin edges. Needs at least two elements when the centers
        have to be derived from it.

    Returns
    -------
    time_bin_centers_s, time_bin_edges_s

    Raises
    ------
    ValueError
        If both inputs are None.
    """
    if time_bin_centers_s is None and time_bin_edges_s is None:
        raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")

    if time_bin_centers_s is None:
        assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
        time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])

    if time_bin_edges_s is None:
        centers = time_bin_centers_s
        # midpoints are generally fractional, so never store them in an integer
        # buffer (that would silently truncate them); floating dtypes are kept as is
        edges_dtype = centers.dtype if np.issubdtype(centers.dtype, np.floating) else np.float64
        time_bin_edges_s = np.empty(centers.shape[0] + 1, dtype=edges_dtype)
        # outer edges are padded with the left and rightmost centers
        time_bin_edges_s[[0, -1]] = centers[[0, -1]]
        # np.empty does not initialize memory, so every interior edge has to be
        # written. This assignment is a no-op for a single center and fills the
        # one interior edge when there are exactly two centers.
        time_bin_edges_s[1:-1] = 0.5 * (centers[1:] + centers[:-1])

    return time_bin_centers_s, time_bin_edges_s


def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None):
    """Return only the bin edges computed by `ensure_time_bins` for the same arguments."""
    return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1]
e022f0cc6c..e4ba870325 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -1,16 +1,14 @@ -from pathlib import Path +import warnings import numpy as np -import pytest import spikeinterface.core as sc -from spikeinterface import download_dataset +from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.motion.motion_interpolation import ( InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, interpolate_motion_on_traces, ) -from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -67,43 +65,45 @@ def test_interpolate_motion_on_traces(): times = rec.get_times()[0:30000] for method in ("kriging", "idw", "nearest"): - traces_corrected = interpolate_motion_on_traces( - traces, - times, - channel_locations, - motion, - channel_inds=None, - spatial_interpolation_method=method, - # spatial_interpolation_kwargs={}, - spatial_interpolation_kwargs={"force_extrapolate": True}, - ) - assert traces.shape == traces_corrected.shape - assert traces.dtype == traces_corrected.dtype + for interpolation_time_bin_centers_s in (None, np.linspace(*times[[0, -1]], num=3)): + traces_corrected = interpolate_motion_on_traces( + traces, + times, + channel_locations, + motion, + channel_inds=None, + spatial_interpolation_method=method, + interpolation_time_bin_centers_s=interpolation_time_bin_centers_s, + # spatial_interpolation_kwargs={}, + spatial_interpolation_kwargs={"force_extrapolate": True}, + ) + assert traces.shape == traces_corrected.shape + assert traces.dtype == traces_corrected.dtype def test_interpolation_simple(): # a recording where a 1 moves at 1 chan per second. 30 chans 10 frames. 
# there will be 9 chans of drift, so we add 9 chans of padding to the bottom - nt = nc0 = 10 # these need to be the same for this test - nc1 = nc0 + nc0 - 1 - traces = np.zeros((nt, nc1), dtype="float32") - traces[:, :nc0] = np.eye(nc0) + n_samples = num_chans_orig = 10 # these need to be the same for this test + num_chans_drifted = num_chans_orig + num_chans_orig - 1 + traces = np.zeros((n_samples, num_chans_drifted), dtype="float32") + traces[:, :num_chans_orig] = np.eye(num_chans_orig) rec = sc.NumpyRecording(traces, sampling_frequency=1) - rec.set_dummy_probe_from_locations(np.c_[np.zeros(nc1), np.arange(nc1)]) + rec.set_dummy_probe_from_locations(np.c_[np.zeros(num_chans_drifted), np.arange(num_chans_drifted)]) - true_motion = Motion(np.arange(nt)[:, None], 0.5 + np.arange(nt), np.zeros(1)) + true_motion = Motion(np.arange(n_samples)[:, None], 0.5 + np.arange(n_samples), np.zeros(1)) rec_corrected = interpolate_motion(rec, true_motion, spatial_interpolation_method="nearest") traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) - assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) - assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) + assert np.array_equal(traces_corrected[:, 0], np.ones(n_samples)) + assert np.array_equal(traces_corrected[:, 1:], np.zeros((n_samples, num_chans_orig - 1))) # let's try a new version where we interpolate too slowly rec_corrected = interpolate_motion( rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2 ) traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) # what happens with nearest here? # well... due to rounding towards the nearest even number, the motion (which at # these time bin centers is 0.5, 2.5, 4.5, ...) 
def test_cross_band_interpolation():
    """Simple version of using LFP to interpolate AP data

    This also tests the time vector implementation in interpolation.
    Two recordings are built which are all 0s except for a 1 that sits on
    channel 5 and jumps to channel 6 at the switch point (3s into the
    recording). They have different sampling frequencies and both get an
    explicit time vector starting at `t_start`. Motion estimated on the
    low-rate (LFP) recording is then used to correct the high-rate (AP) one,
    and the corrected traces must be exactly 1 on a single channel.
    """
    from spikeinterface.sortingcomponents.motion import estimate_motion

    # sampling freqs and timing for AP and LFP recordings
    fs_lfp = 50.0
    fs_ap = 300.0
    t_start = 10.0
    total_duration = 5.0
    num_samples_lfp = int(fs_lfp * total_duration)
    num_samples_ap = int(fs_ap * total_duration)
    t_switch = 3

    # because interpolation uses bin centers logic, there will be a half
    # bin offset at the change point in the AP recording.
    halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp))

    # sample index of the channel switch in each recording
    switch_index_lfp = int(t_switch * fs_lfp)
    switch_index_ap = int(t_switch * fs_ap) - halfbin_ap_lfp

    # channel geometry
    num_chans = 10
    geom = np.c_[np.zeros(num_chans), np.arange(num_chans)]

    # make an LFP recording which drifts a bit
    traces_lfp = np.zeros((num_samples_lfp, num_chans))
    traces_lfp[:switch_index_lfp, 5] = 1.0
    traces_lfp[switch_index_lfp:, 6] = 1.0
    rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp)
    rec_lfp.set_dummy_probe_from_locations(geom)

    # same for AP
    traces_ap = np.zeros((num_samples_ap, num_chans))
    traces_ap[:switch_index_ap, 5] = 1.0
    traces_ap[switch_index_ap:, 6] = 1.0
    rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap)
    rec_ap.set_dummy_probe_from_locations(geom)

    # set times for both, and silence the warning
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        rec_lfp.set_times(t_start + np.arange(num_samples_lfp) / fs_lfp)
        rec_ap.set_times(t_start + np.arange(num_samples_ap) / fs_ap)

    # estimate motion
    motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True)

    # nearest to keep it simple
    rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2)
    traces_corrected = rec_corrected.get_traces()

    # the corrected recording is expected to have two channels fewer than the
    # input, with the signal on column 4 for every sample
    target = np.zeros((num_samples_ap, num_chans - 2))
    target[:, 4] = 1
    assert np.array_equal(traces_corrected, target)


if __name__ == "__main__":
    # test_correct_motion_on_peaks()
    test_interpolate_motion_on_traces()
    # test_interpolation_simple()
    # test_InterpolateMotionRecording()
    test_cross_band_interpolation()