From f96e2b79df22b699cfd380b126b4c977339f3ab1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 13 Sep 2023 19:28:00 +0200 Subject: [PATCH 1/7] Extend and refactor waveform metrics --- src/spikeinterface/postprocessing/__init__.py | 1 - .../postprocessing/template_metrics.py | 583 ++++++++++++++++-- .../tests/test_template_metrics.py | 8 +- 3 files changed, 534 insertions(+), 58 deletions(-) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 223bda5e30..d7e1ffac01 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -10,7 +10,6 @@ from .template_metrics import ( TemplateMetricsCalculator, compute_template_metrics, - calculate_template_metrics, get_template_metric_names, ) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 681f6f3e84..119f0dc53d 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -11,9 +11,24 @@ from ..core.waveform_extractor import BaseWaveformExtractorExtension import warnings +# DEBUG = True + +# if DEBUG: +# import matplotlib.pyplot as plt +# plt.ion() +# plt.show() + + +def get_1d_template_metric_names(): + return deepcopy(list(_1d_metric_name_to_func.keys())) + + +def get_2d_template_metric_names(): + return deepcopy(list(_2d_metric_name_to_func.keys())) + def get_template_metric_names(): - return deepcopy(list(_metric_name_to_func.keys())) + return get_1d_template_metric_names() + get_2d_template_metric_names() class TemplateMetricsCalculator(BaseWaveformExtractorExtension): @@ -26,20 +41,31 @@ class TemplateMetricsCalculator(BaseWaveformExtractorExtension): """ extension_name = "template_metrics" + min_channels_for_2d_warning = 10 def __init__(self, waveform_extractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) - def _set_params(self, 
metric_names=None, peak_sign="neg", upsampling_factor=10, sparsity=None, window_slope_ms=0.7): + def _set_params( + self, + metric_names=None, + peak_sign="neg", + upsampling_factor=10, + sparsity=None, + functions_kwargs=None, + include_2d_metrics=False, + ): if metric_names is None: - metric_names = get_template_metric_names() - + metric_names = get_1d_template_metric_names() + if include_2d_metrics: + metric_names += get_2d_template_metric_names() + functions_kwargs = functions_kwargs or dict() params = dict( metric_names=[str(name) for name in metric_names], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - window_slope_ms=float(window_slope_ms), + functions_kwargs=functions_kwargs, ) return params @@ -60,6 +86,9 @@ def _run(self): unit_ids = self.waveform_extractor.sorting.unit_ids sampling_frequency = self.waveform_extractor.sampling_frequency + metrics_1d = [m for m in metric_names if m in get_1d_template_metric_names()] + metrics_2d = [m for m in metric_names if m in get_2d_template_metric_names()] + if sparsity is None: extremum_channels_ids = get_template_extremum_channel( self.waveform_extractor, peak_sign=peak_sign, outputs="id" @@ -79,6 +108,8 @@ def _run(self): template_metrics = pd.DataFrame(index=multi_index, columns=metric_names) all_templates = self.waveform_extractor.get_all_templates() + channel_locations = self.waveform_extractor.get_channel_locations() + for unit_index, unit_id in enumerate(unit_ids): template_all_chans = all_templates[unit_index] chan_ids = np.array(extremum_channels_ids[unit_id]) @@ -87,6 +118,7 @@ def _run(self): chan_ind = self.waveform_extractor.channel_ids_to_indices(chan_ids) template = template_all_chans[:, chan_ind] + # compute 1d metrics for i, template_single in enumerate(template.T): if sparsity is None: index = unit_id @@ -100,15 +132,50 @@ def _run(self): template_upsampled = template_single sampling_frequency_up = sampling_frequency - for metric_name in metric_names: + 
trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled) + + for metric_name in metrics_1d: func = _metric_name_to_func[metric_name] value = func( template_upsampled, sampling_frequency=sampling_frequency_up, - window_ms=self._params["window_slope_ms"], + trough_idx=trough_idx, + peak_idx=peak_idx, + **self._params["functions_kwargs"], ) template_metrics.at[index, metric_name] = value + # compute metrics 2d + for metric_name in metrics_2d: + # retrieve template (with sparsity if waveform extractor is sparse) + template = self.waveform_extractor.get_template(unit_id=unit_id) + + if template.shape[1] < self.min_channels_for_2d_warning: + warnings.warn( + f"With less than {self.min_channels_for_2d_warning} channels, " + "2D metrics might not be reliable." + ) + if self.waveform_extractor.is_sparse(): + channel_locations_sparse = channel_locations[self.waveform_extractor.sparsity.mask[unit_index]] + else: + channel_locations_sparse = channel_locations + + if upsampling_factor > 1: + assert isinstance(upsampling_factor, (int, np.integer)), "'upsample' must be an integer" + template_upsampled = resample_poly(template, up=upsampling_factor, down=1, axis=0) + sampling_frequency_up = upsampling_factor * sampling_frequency + else: + template_upsampled = template + sampling_frequency_up = sampling_frequency + + func = _metric_name_to_func[metric_name] + value = func( + template_upsampled, + channel_locations=channel_locations_sparse, + sampling_frequency=sampling_frequency_up, + **self._params["functions_kwargs"], + ) + template_metrics.at[index, metric_name] = value self._extension_data["metrics"] = template_metrics def get_data(self): @@ -139,7 +206,17 @@ def compute_template_metrics( peak_sign="neg", upsampling_factor=10, sparsity=None, - window_slope_ms=0.7, + include_2d_metrics=False, + functions_kwargs=dict( + recovery_window_ms=0.7, + peak_relative_threshold=0.2, + peak_width_ms=0.2, + depth_direction="y", + min_channels_for_velocity=5, + 
min_r2_for_velocity=0.5, + exp_peak_function="ptp", + spread_threshold=0.2, + ), ): """ Compute template metrics including: @@ -148,6 +225,14 @@ def compute_template_metrics( * halfwidth * repolarization_slope * recovery_slope + * num_positive_peaks + * num_negative_peaks + + Optionally, the following 2d metrics can be computed (when include_2d_metrics=True): + * velocity_above + * velocity_below + * exp_decay + * spread Parameters ---------- @@ -157,34 +242,57 @@ def compute_template_metrics( Whether to load precomputed template metrics, if they already exist. metric_names : list, optional List of metrics to compute (see si.postprocessing.get_template_metric_names()), by default None - peak_sign : str, optional - "pos" | "neg", by default 'neg' - upsampling_factor : int, optional - Upsample factor, by default 10 - sparsity: dict or None + peak_sign : {"neg", "pos"}, default: "neg" + The peak sign + upsampling_factor : int, default: 10 + The upsampling factor to upsample the templates + sparsity: dict or None, default: None Default is sparsity=None and template metric is computed on extremum channel only. If given, the dictionary should contain a unit ids as keys and a channel id or a list of channel ids as values. For more generating a sparsity dict, see the postprocessing.compute_sparsity() function. - window_slope_ms: float - Window in ms after the positiv peak to compute slope, by default 0.7 + include_2d_metrics: bool, default: False + Whether to compute 2d metrics + functions_kwargs: dict + Additional arguments to pass to the metric functions. 
Including: + * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 + * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 + * peak_width_ms: the width in samples to detect peaks, default: 0.2 + * depth_direction: the direction to compute velocity above and below, default: "y" + * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 + * min_r2_for_velocity: the minimum r2 to accept the velocity fit, default: 0.7 + * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" + * spread_threshold: the threshold to compute the spread, default: 0.2 Returns ------- - tempalte_metrics : pd.DataFrame + template_metrics : pd.DataFrame Dataframe with the computed template metrics. If 'sparsity' is None, the index is the unit_id. If 'sparsity' is given, the index is a multi-index (unit_id, channel_id) + + Notes + ----- + If any 2d metric is in the metric_names or include_2d_metrics is True, sparsity must be None, so that one metric + value will be computed per unit. """ if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name): tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name) else: tmc = TemplateMetricsCalculator(waveform_extractor) + # For 2D metrics, external sparsity must be None, so that one metric value will be computed per unit. + if include_2d_metrics or ( + metric_names is not None and any([m in get_2d_template_metric_names() for m in metric_names]) + ): + assert ( + sparsity is None + ), "If 2D metrics are computed, sparsity must be None, so that each unit will correspond to 1 row of the output dataframe." 
tmc.set_params( metric_names=metric_names, peak_sign=peak_sign, upsampling_factor=upsampling_factor, sparsity=sparsity, - window_slope_ms=window_slope_ms, + include_2d_metrics=include_2d_metrics, + functions_kwargs=functions_kwargs, ) tmc.run() @@ -197,7 +305,19 @@ def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough (minimum of template) and peak (maximum of template, after trough). - Assumes negative trough and positive peak + Assumes negative trough and positive peak. + + Parameters + ---------- + template: numpy.ndarray + The 1D template waveform + + Returns + ------- + trough_idx: int + The index of the trough + peak_idx: int + The index of the peak """ assert template.ndim == 1 trough_idx = np.argmin(template) @@ -205,41 +325,94 @@ def get_trough_and_peak_idx(template): return trough_idx, peak_idx -def get_peak_to_valley(template, **kwargs): +######################################################################################### +# 1D metrics +def get_peak_to_valley(template_single, trough_idx=None, peak_idx=None, **kwargs): """ - Time between trough and peak in s + Return the peak to valley duration in seconds of input waveforms. 
+ + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - sampling_frequency: the sampling frequency + + Returns + ------- + ptv: float + The peak to valley duration in seconds """ sampling_frequency = kwargs["sampling_frequency"] - trough_idx, peak_idx = get_trough_and_peak_idx(template) + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) ptv = (peak_idx - trough_idx) / sampling_frequency return ptv -def get_peak_trough_ratio(template, **kwargs): +def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwargs): """ - Ratio between peak heigth and trough depth + Return the peak to trough ratio of input waveforms. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - sampling_frequency: the sampling frequency + + Returns + ------- + ptratio: float + The peak to trough ratio """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) - ptratio = template[peak_idx] / template[trough_idx] + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) + ptratio = template_single[peak_idx] / template_single[trough_idx] return ptratio -def get_half_width(template, **kwargs): +def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): """ - Width of waveform at its half of amplitude in s + Return the half width of input waveforms in seconds. 
+ + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - sampling_frequency: the sampling frequency + + Returns + ------- + hw: float + The half width in seconds """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) sampling_frequency = kwargs["sampling_frequency"] if peak_idx == 0: return np.nan - trough_val = template[trough_idx] + trough_val = template_single[trough_idx] # threshold is half of peak heigth (assuming baseline is 0) threshold = 0.5 * trough_val - (cpre_idx,) = np.where(template[:trough_idx] < threshold) - (cpost_idx,) = np.where(template[trough_idx:] < threshold) + (cpre_idx,) = np.where(template_single[:trough_idx] < threshold) + (cpost_idx,) = np.where(template_single[trough_idx:] < threshold) if len(cpre_idx) == 0 or len(cpost_idx) == 0: hw = np.nan @@ -254,7 +427,7 @@ def get_half_width(template, **kwargs): return hw -def get_repolarization_slope(template, **kwargs): +def get_repolarization_slope(template_single, trough_idx=None, **kwargs): """ Return slope of repolarization period between trough and baseline @@ -264,17 +437,26 @@ def get_repolarization_slope(template, **kwargs): Optionally the function returns also the indices per waveform where the potential crosses baseline. 
- """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + trough_idx: int, default: None + The index of the trough + **kwargs: Required kwargs: + - sampling_frequency: the sampling frequency + """ + if trough_idx is None: + trough_idx = get_trough_and_peak_idx(template_single) sampling_frequency = kwargs["sampling_frequency"] - times = np.arange(template.shape[0]) / sampling_frequency + times = np.arange(template_single.shape[0]) / sampling_frequency if trough_idx == 0: return np.nan - (rtrn_idx,) = np.nonzero(template[trough_idx:] >= 0) + (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0) if len(rtrn_idx) == 0: return np.nan # first time after trough, where template is at baseline @@ -285,11 +467,11 @@ def get_repolarization_slope(template, **kwargs): import scipy.stats - res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template[trough_idx:return_to_base_idx]) + res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template_single[trough_idx:return_to_base_idx]) return res.slope -def get_recovery_slope(template, window_ms=0.7, **kwargs): +def get_recovery_slope(template_single, peak_idx=None, **kwargs): """ Return the recovery slope of input waveforms. After repolarization, the neuron hyperpolarizes untill it peaks. The recovery slope is the @@ -299,41 +481,332 @@ def get_recovery_slope(template, window_ms=0.7, **kwargs): Takes a numpy array of waveforms and returns an array with recovery slopes per waveform. 
+ + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - sampling_frequency: the sampling frequency + - recovery_window_ms: the window in ms after the peak to compute the recovery_slope """ + import scipy.stats - trough_idx, peak_idx = get_trough_and_peak_idx(template) + assert "recovery_window_ms" in kwargs, "recovery_window_ms must be given as kwarg" + recovery_window_ms = kwargs["recovery_window_ms"] + if peak_idx is None: + _, peak_idx = get_trough_and_peak_idx(template_single) sampling_frequency = kwargs["sampling_frequency"] - times = np.arange(template.shape[0]) / sampling_frequency + times = np.arange(template_single.shape[0]) / sampling_frequency if peak_idx == 0: return np.nan - max_idx = int(peak_idx + ((window_ms / 1000) * sampling_frequency)) - max_idx = np.min([max_idx, template.shape[0]]) - - import scipy.stats + max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency)) + max_idx = np.min([max_idx, template_single.shape[0]]) - res = scipy.stats.linregress(times[peak_idx:max_idx], template[peak_idx:max_idx]) + res = scipy.stats.linregress(times[peak_idx:max_idx], template_single[peak_idx:max_idx]) return res.slope -_metric_name_to_func = { +def get_num_positive_peaks(template_single, **kwargs): + """ + Count the number of positive peaks in the template. 
+ + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + **kwargs: Required kwargs: + - peak_relative_threshold: the relative threshold to detect positive and negative peaks + - peak_width_ms: the width in samples to detect peaks + - sampling_frequency: the sampling frequency + """ + from scipy.signal import find_peaks + + assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" + assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" + peak_relative_threshold = kwargs["peak_relative_threshold"] + peak_width_ms = kwargs["peak_width_ms"] + max_value = np.max(np.abs(template_single)) + peak_width_samples = int(peak_width_ms / 1000 * kwargs["sampling_frequency"]) + + pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) + + return len(pos_peaks[0]) + + +def get_num_negative_peaks(template_single, **kwargs): + """ + Count the number of negative peaks in the template. 
+ + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + **kwargs: Required kwargs: + - peak_relative_threshold: the relative threshold to detect positive and negative peaks + - peak_width_ms: the width in samples to detect peaks + - sampling_frequency: the sampling frequency + """ + from scipy.signal import find_peaks + + assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" + assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" + peak_relative_threshold = kwargs["peak_relative_threshold"] + peak_width_ms = kwargs["peak_width_ms"] + max_value = np.max(np.abs(template_single)) + peak_width_samples = int(peak_width_ms / 1000 * kwargs["sampling_frequency"]) + + neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) + + return len(neg_peaks[0]) + + +_1d_metric_name_to_func = { "peak_to_valley": get_peak_to_valley, "peak_trough_ratio": get_peak_trough_ratio, "half_width": get_half_width, "repolarization_slope": get_repolarization_slope, "recovery_slope": get_recovery_slope, + "num_positive_peaks": get_num_positive_peaks, + "num_negative_peaks": get_num_negative_peaks, } -# back-compatibility -def calculate_template_metrics(*args, **kwargs): - warnings.warn( - "The 'calculate_template_metrics' function is deprecated. 
" "Use 'compute_template_metrics' instead", - DeprecationWarning, - stacklevel=2, - ) - return compute_template_metrics(*args, **kwargs) +######################################################################################### +# 2D metrics + + +def fit_velocity(peak_times, channel_dist): + # from scipy.stats import linregress + # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) + + from sklearn.linear_model import TheilSenRegressor + + theil = TheilSenRegressor() + theil.fit(peak_times.reshape(-1, 1), channel_dist) + slope = theil.coef_[0] + intercept = theil.intercept_ + score = theil.score(peak_times.reshape(-1, 1), channel_dist) + return slope, intercept, score + + +def get_velocity_above(template, channel_locations, **kwargs): + """ + Compute the velocity above the max channel of the template. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - min_channels_for_velocity: the minimum number of channels above or below to compute velocity + - min_r2_for_velocity: the minimum r2 to accept the velocity fit + - sampling_frequency: the sampling frequency + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" + assert "min_r2_for_velocity" in kwargs, "min_r2_for_velocity must be given as kwarg" + + depth_direction = kwargs["depth_direction"] + min_channels_for_velocity = kwargs["min_channels_for_velocity"] + min_r2_for_velocity = kwargs["min_r2_for_velocity"] + + direction_index = ["x", "y", "z"].index(depth_direction) + sampling_frequency = kwargs["sampling_frequency"] + + # find location of max channel + max_sample_idx, max_channel_idx = 
np.unravel_index(np.argmin(template), template.shape) + max_channel_location = channel_locations[max_channel_idx] + + channels_above = channel_locations[:, direction_index] >= max_channel_location[direction_index] + + # we only consider samples forward in time with respect to the max channel + template_above = template[max_sample_idx:, channels_above] + channel_locations_above = channel_locations[channels_above] + + peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 + distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) + velocity_above, intercept, score = fit_velocity(peak_times_ms_above, distances_um_above) + + # if DEBUG: + # fig, ax = plt.subplots() + # ax.plot(peak_times_ms_above, distances_um_above, "o") + # x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) + # ax.plot(x, intercept + x * velocity_above) + # ax.set_xlabel("Peak time (ms)") + # ax.set_ylabel("Distance from max channel (um)") + # ax.set_title(f"Velocity above: {velocity_above:.2f} um/ms") + + if np.sum(channels_above) < min_channels_for_velocity: + # if DEBUG: + # ax.set_title("NaN velocity - not enough channels") + return np.nan + + if score < min_r2_for_velocity: + # if DEBUG: + # ax.set_title(f"NaN velocity - R2 is too low: {score:.2f}") + return np.nan + return velocity_above + + +def get_velocity_below(template, channel_locations, **kwargs): + """ + Compute the velocity below the max channel of the template. 
+ + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - min_channels_for_velocity: the minimum number of channels above or below to compute velocity + - min_r2_for_velocity: the minimum r2 to accept the velocity fit + - sampling_frequency: the sampling frequency + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" + assert "min_r2_for_velocity" in kwargs, "min_r2_for_velocity must be given as kwarg" + direction = kwargs["depth_direction"] + min_channels_for_velocity = kwargs["min_channels_for_velocity"] + min_r2_for_velocity = kwargs["min_r2_for_velocity"] + direction_index = ["x", "y", "z"].index(direction) + + # find location of max channel + max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_channel_location = channel_locations[max_channel_idx] + sampling_frequency = kwargs["sampling_frequency"] + + channels_below = channel_locations[:, direction_index] <= max_channel_location[direction_index] + + # we only consider samples forward in time with respect to the max channel + template_below = template[max_sample_idx:, channels_below] + channel_locations_below = channel_locations[channels_below] + + peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 + distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) + velocity_below, intercept, score = fit_velocity(peak_times_ms_below, distances_um_below) + + # if DEBUG: + # fig, ax = plt.subplots() + # ax.plot(peak_times_ms_below, distances_um_below, "o") + # x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 
20) + # ax.plot(x, intercept + x * velocity_below) + # ax.set_xlabel("Peak time (ms)") + # ax.set_ylabel("Distance from max channel (um)") + # ax.set_title(f"Velocity below: {np.round(velocity_below, 3)} um/ms") + + if np.sum(channels_below) < min_channels_for_velocity: + # if DEBUG: + # ax.set_title("NaN velocity - not enough channels") + return np.nan + + if score < min_r2_for_velocity: + # if DEBUG: + # ax.set_title(f"NaN velocity - R2 is too low: {np.round(score, 3)}") + return np.nan + + return velocity_below + + +def get_exp_decay(template, channel_locations, **kwargs): + """ + Compute the exponential decay of the template amplitude over distance. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + **kwargs: Required kwargs: + - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") + """ + from scipy.optimize import curve_fit + + def exp_decay(x, a, b, c): + return a * np.exp(-b * x) + c + + assert "exp_peak_function" in kwargs, "exp_peak_function must be given as kwarg" + exp_peak_function = kwargs["exp_peak_function"] + # exp decay fit + if exp_peak_function == "ptp": + fun = np.ptp + elif exp_peak_function == "min": + fun = np.min + peak_amplitudes = np.abs(fun(template, axis=0)) + max_channel_location = channel_locations[np.argmax(peak_amplitudes)] + channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) + distances_sort_indices = np.argsort(channel_distances) + channel_distances_sorted = channel_distances[distances_sort_indices] + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices] + try: + popt, _ = curve_fit(exp_decay, channel_distances_sorted, peak_amplitudes_sorted) + exp_decay_value = popt[1] + # if DEBUG: + # fig, ax = plt.subplots() + # ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") + # x = 
np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) + # ax.plot(x, exp_decay(x, *popt)) + # ax.set_xlabel("Distance from max channel (um)") + # ax.set_ylabel("Peak amplitude") + # ax.set_title(f"Exp decay: {np.round(exp_decay_value, 3)}") + except: + exp_decay_value = np.nan + return exp_decay_value + + +def get_spread(template, channel_locations, **kwargs): + """ + Compute the spread of the template amplitude over distance. + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - spread_threshold: the threshold to compute the spread + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + depth_direction = kwargs["depth_direction"] + assert "spread_threshold" in kwargs, "spread_threshold must be given as kwarg" + spread_threshold = kwargs["spread_threshold"] + + direction_index = ["x", "y", "z"].index(depth_direction) + MM = np.ptp(template, 0) + MM = MM / np.max(MM) + channel_locations_above_theshold = channel_locations[MM > spread_threshold] + channel_depth_above_theshold = channel_locations_above_theshold[:, direction_index] + spread = np.ptp(channel_depth_above_theshold) + + # if DEBUG: + # fig, ax = plt.subplots() + # channel_depths = channel_locations[:, direction_index] + # sort_indices = np.argsort(channel_depths) + # ax.plot(channel_depths[sort_indices], MM[sort_indices], "o-") + # ax.axhline(spread_threshold, ls="--", color="r") + # ax.set_xlabel("Depth (um)") + # ax.set_ylabel("Amplitude") + # ax.set_title(f"Spread: {np.round(spread, 3)} um") + return spread + + +_2d_metric_name_to_func = { + "velocity_above": get_velocity_above, + "velocity_below": get_velocity_below, + "exp_decay": get_exp_decay, + "spread": get_spread, +} -calculate_template_metrics.__doc__ 
= compute_template_metrics.__doc__ +_metric_name_to_func = {**_1d_metric_name_to_func, **_2d_metric_name_to_func} diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 9895e2ec4c..5dcff3ffba 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -17,9 +17,13 @@ def test_sparse_metrics(self): tm_sparse = self.extension_class.get_extension_function()(self.we1, sparsity=self.sparsity1) print(tm_sparse) + def test_2d_metrics(self): + tm_2d = self.extension_class.get_extension_function()(self.we1, include_2d_metrics=True) + print(tm_2d) + if __name__ == "__main__": test = TemplateMetricsExtensionTest() test.setUp() - test.test_extension() - test.test_sparse_metrics() + # test.test_extension() + test.test_2d_metrics() From 226ad852e25596c0f6072f48a72e2e3d4a84afab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 12:32:33 +0200 Subject: [PATCH 2/7] Update tests --- .../postprocessing/tests/test_template_metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 5dcff3ffba..a27ccc77f8 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -17,13 +17,13 @@ def test_sparse_metrics(self): tm_sparse = self.extension_class.get_extension_function()(self.we1, sparsity=self.sparsity1) print(tm_sparse) - def test_2d_metrics(self): - tm_2d = self.extension_class.get_extension_function()(self.we1, include_2d_metrics=True) - print(tm_2d) + def test_multi_channel_metrics(self): + tm_multi = self.extension_class.get_extension_function()(self.we1, include_multi_channel_metrics=True) + print(tm_multi) if __name__ == 
"__main__": test = TemplateMetricsExtensionTest() test.setUp() # test.test_extension() - test.test_2d_metrics() + test.test_multi_channel_metrics() From 00f91eb99de0052daf6ae67a47026e1490bcd278 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 25 Sep 2023 12:02:51 +0200 Subject: [PATCH 3/7] Do not save/overwrite params in read-only mode --- src/spikeinterface/core/waveform_extractor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 6881ab3ec5..9f85603e51 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1988,6 +1988,9 @@ def set_params(self, **params): params = self._set_params(**params) self._params = params + if self.waveform_extractor.is_read_only(): + return + params_to_save = params.copy() if "sparsity" in params and params["sparsity"] is not None: assert isinstance( From 7ba84ad7d9913b4846d9d6903a13a1f441156647 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 28 Sep 2023 12:25:29 +0200 Subject: [PATCH 4/7] updates --- src/spikeinterface/core/waveform_extractor.py | 24 +- .../postprocessing/template_metrics.py | 343 +++++++++++------- 2 files changed, 239 insertions(+), 128 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 9f85603e51..79456a40ce 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -811,14 +811,30 @@ def select_units(self, unit_ids, new_folder=None, use_relative_path: bool = Fals sparsity = ChannelSparsity(mask, unit_ids, self.channel_ids) else: sparsity = None - we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) - we.set_params(**self._params) + if self.has_recording(): + we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) + 
else: + we = WaveformExtractor( + recording=None, + sorting=sorting, + folder=None, + sparsity=sparsity, + rec_attributes=self._rec_attributes, + allow_unfiltered=True, + ) + we._params = self._params # copy memory objects if self.has_waveforms(): we._memory_objects = {"wfs_arrays": {}, "sampled_indices": {}} for unit_id in unit_ids: - we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] - we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][unit_id] + if self.format == "memory": + we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] + we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][ + unit_id + ] + else: + we._memory_objects["wfs_arrays"][unit_id] = self.get_waveforms(unit_id) + we._memory_objects["sampled_indices"][unit_id] = self.get_sampled_indices(unit_id) # finally select extensions data for ext_name in self.get_available_extension_names(): diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index ea44dea9cb..090dae4567 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -11,12 +11,9 @@ from ..core.waveform_extractor import BaseWaveformExtractorExtension import warnings -# DEBUG = True -# if DEBUG: -# import matplotlib.pyplot as plt -# plt.ion() -# plt.show() +global DEBUG +DEBUG = False def get_single_channel_template_metric_names(): @@ -52,20 +49,20 @@ def _set_params( peak_sign="neg", upsampling_factor=10, sparsity=None, - functions_kwargs=None, + metrics_kwargs=None, include_multi_channel_metrics=False, ): if metric_names is None: metric_names = get_single_channel_template_metric_names() if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - functions_kwargs = functions_kwargs or dict() + metrics_kwargs = metrics_kwargs or 
dict() params = dict( metric_names=[str(name) for name in metric_names], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - functions_kwargs=functions_kwargs, + metrics_kwargs=metrics_kwargs, ) return params @@ -141,7 +138,7 @@ def _run(self): sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self._params["functions_kwargs"], + **self._params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value @@ -173,7 +170,7 @@ def _run(self): template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self._params["functions_kwargs"], + **self._params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value self._extension_data["metrics"] = template_metrics @@ -199,6 +196,21 @@ def get_extension_function(): WaveformExtractor.register_extension(TemplateMetricsCalculator) +_default_function_kwargs = dict( + recovery_window_ms=0.7, + peak_relative_threshold=0.2, + peak_width_ms=0.1, + depth_direction="y", + min_channels_for_velocity=5, + min_r2_velocity=0.5, + exp_peak_function="ptp", + min_r2_exp_decay=0.5, + spread_threshold=0.2, + spread_smooth_um=20, + same_x=False, +) + + def compute_template_metrics( waveform_extractor, load_if_exists=False, @@ -207,16 +219,8 @@ def compute_template_metrics( upsampling_factor=10, sparsity=None, include_multi_channel_metrics=False, - functions_kwargs=dict( - recovery_window_ms=0.7, - peak_relative_threshold=0.2, - peak_width_ms=0.2, - depth_direction="y", - min_channels_for_velocity=5, - min_r2_for_velocity=0.5, - exp_peak_function="ptp", - spread_threshold=0.2, - ), + metrics_kwargs=None, + debug_plots=False, ): """ Compute template metrics including: @@ -252,14 +256,14 @@ def compute_template_metrics( For more generating a sparsity dict, see the postprocessing.compute_sparsity() function. 
include_multi_channel_metrics: bool, default: False Whether to compute multi-channel metrics - functions_kwargs: dict + metrics_kwargs: dict Additional arguments to pass to the metric functions. Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 * peak_width_ms: the width in samples to detect peaks, default: 0.2 * depth_direction: the direction to compute velocity above and below, default: "y" * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 - * min_r2_for_velocity: the minimum r2 to accept the velocity fit, default: 0.7 + * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" * spread_threshold: the threshold to compute the spread, default: 0.2 @@ -275,6 +279,9 @@ def compute_template_metrics( If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, so that one metric value will be computed per unit. """ + if debug_plots: + global DEBUG + DEBUG = True if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name): tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name) else: @@ -287,13 +294,19 @@ def compute_template_metrics( "If multi-channel metrics are computed, sparsity must be None, " "so that each unit will correspond to 1 row of the output dataframe." 
) + default_kwargs = _default_function_kwargs.copy() + if metrics_kwargs is None: + metrics_kwargs = default_kwargs + else: + default_kwargs.update(metrics_kwargs) + metrics_kwargs = default_kwargs tmc.set_params( metric_names=metric_names, peak_sign=peak_sign, upsampling_factor=upsampling_factor, sparsity=sparsity, include_multi_channel_metrics=include_multi_channel_metrics, - functions_kwargs=functions_kwargs, + metrics_kwargs=metrics_kwargs, ) tmc.run() @@ -328,7 +341,7 @@ def get_trough_and_peak_idx(template): ######################################################################################### # Single-channel metrics -def get_peak_to_valley(template_single, trough_idx=None, peak_idx=None, **kwargs): +def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): """ Return the peak to valley duration in seconds of input waveforms. @@ -340,22 +353,19 @@ def get_peak_to_valley(template_single, trough_idx=None, peak_idx=None, **kwargs The index of the trough peak_idx: int, default: None The index of the peak - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency Returns ------- ptv: float The peak to valley duration in seconds """ - sampling_frequency = kwargs["sampling_frequency"] if trough_idx is None or peak_idx is None: trough_idx, peak_idx = get_trough_and_peak_idx(template_single) ptv = (peak_idx - trough_idx) / sampling_frequency return ptv -def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwargs): +def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs): """ Return the peak to trough ratio of input waveforms. 
@@ -367,8 +377,6 @@ def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwa The index of the trough peak_idx: int, default: None The index of the peak - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency Returns ------- @@ -381,7 +389,7 @@ def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwa return ptratio -def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): +def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): """ Return the half width of input waveforms in seconds. @@ -393,8 +401,6 @@ def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): The index of the trough peak_idx: int, default: None The index of the peak - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency Returns ------- @@ -403,7 +409,6 @@ def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): """ if trough_idx is None or peak_idx is None: trough_idx, peak_idx = get_trough_and_peak_idx(template_single) - sampling_frequency = kwargs["sampling_frequency"] if peak_idx == 0: return np.nan @@ -428,7 +433,7 @@ def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): return hw -def get_repolarization_slope(template_single, trough_idx=None, **kwargs): +def get_repolarization_slope(template_single, sampling_frequency, trough_idx=None, **kwargs): """ Return slope of repolarization period between trough and baseline @@ -445,12 +450,9 @@ def get_repolarization_slope(template_single, trough_idx=None, **kwargs): The 1D template waveform trough_idx: int, default: None The index of the trough - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency """ if trough_idx is None: trough_idx = get_trough_and_peak_idx(template_single) - sampling_frequency = kwargs["sampling_frequency"] times = np.arange(template_single.shape[0]) / sampling_frequency @@ -472,7 +474,7 @@ 
def get_repolarization_slope(template_single, trough_idx=None, **kwargs): return res.slope -def get_recovery_slope(template_single, peak_idx=None, **kwargs): +def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwargs): """ Return the recovery slope of input waveforms. After repolarization, the neuron hyperpolarizes untill it peaks. The recovery slope is the @@ -490,7 +492,6 @@ def get_recovery_slope(template_single, peak_idx=None, **kwargs): peak_idx: int, default: None The index of the peak **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency - recovery_window_ms: the window in ms after the peak to compute the recovery_slope """ import scipy.stats @@ -499,7 +500,6 @@ def get_recovery_slope(template_single, peak_idx=None, **kwargs): recovery_window_ms = kwargs["recovery_window_ms"] if peak_idx is None: _, peak_idx = get_trough_and_peak_idx(template_single) - sampling_frequency = kwargs["sampling_frequency"] times = np.arange(template_single.shape[0]) / sampling_frequency @@ -512,7 +512,7 @@ def get_recovery_slope(template_single, peak_idx=None, **kwargs): return res.slope -def get_num_positive_peaks(template_single, **kwargs): +def get_num_positive_peaks(template_single, sampling_frequency, **kwargs): """ Count the number of positive peaks in the template. 
@@ -523,7 +523,6 @@ def get_num_positive_peaks(template_single, **kwargs): **kwargs: Required kwargs: - peak_relative_threshold: the relative threshold to detect positive and negative peaks - peak_width_ms: the width in samples to detect peaks - - sampling_frequency: the sampling frequency """ from scipy.signal import find_peaks @@ -532,14 +531,14 @@ def get_num_positive_peaks(template_single, **kwargs): peak_relative_threshold = kwargs["peak_relative_threshold"] peak_width_ms = kwargs["peak_width_ms"] max_value = np.max(np.abs(template_single)) - peak_width_samples = int(peak_width_ms / 1000 * kwargs["sampling_frequency"]) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) return len(pos_peaks[0]) -def get_num_negative_peaks(template_single, **kwargs): +def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): """ Count the number of negative peaks in the template. 
@@ -550,7 +549,6 @@ def get_num_negative_peaks(template_single, **kwargs): **kwargs: Required kwargs: - peak_relative_threshold: the relative threshold to detect positive and negative peaks - peak_width_ms: the width in samples to detect peaks - - sampling_frequency: the sampling frequency """ from scipy.signal import find_peaks @@ -559,7 +557,7 @@ def get_num_negative_peaks(template_single, **kwargs): peak_relative_threshold = kwargs["peak_relative_threshold"] peak_width_ms = kwargs["peak_width_ms"] max_value = np.max(np.abs(template_single)) - peak_width_samples = int(peak_width_ms / 1000 * kwargs["sampling_frequency"]) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) @@ -581,6 +579,20 @@ def get_num_negative_peaks(template_single, **kwargs): # Multi-channel metrics +def transform_same_x(template, channel_locations): + max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] + same_x_mask = channel_locations[:, 0] == max_channel_x + channel_locations_same_x = channel_locations[same_x_mask] + template_same_x = template[:, same_x_mask] + return template_same_x, channel_locations_same_x + + +def sort_template_and_locations(template, channel_locations, depth_direction="y"): + direction_index = ["x", "y", "z"].index(depth_direction) + sort_indices = np.argsort(channel_locations[:, direction_index]) + return template[:, sort_indices], channel_locations[sort_indices, :] + + def fit_velocity(peak_times, channel_dist): # from scipy.stats import linregress # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) @@ -595,7 +607,7 @@ def fit_velocity(peak_times, channel_dist): return slope, intercept, score -def get_velocity_above(template, channel_locations, **kwargs): +def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs): """ Compute the velocity above the max channel of the template. 
@@ -608,56 +620,70 @@ def get_velocity_above(template, channel_locations, **kwargs): **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - - min_r2_for_velocity: the minimum r2 to accept the velocity fit - - sampling_frequency: the sampling frequency + - min_r2_velocity: the minimum r2 to accept the velocity fit """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" - assert "min_r2_for_velocity" in kwargs, "min_r2_for_velocity must be given as kwarg" + assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" + assert "same_x" in kwargs, "same_x must be given as kwarg" depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] - min_r2_for_velocity = kwargs["min_r2_for_velocity"] + min_r2_velocity = kwargs["min_r2_velocity"] + same_x = kwargs["same_x"] direction_index = ["x", "y", "z"].index(depth_direction) - sampling_frequency = kwargs["sampling_frequency"] + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + if same_x: + template, channel_locations = transform_same_x(template, channel_locations) # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] channels_above = channel_locations[:, direction_index] >= max_channel_location[direction_index] # we only consider samples forward in time with respect to the max channel - template_above = template[max_sample_idx:, channels_above] + # TODO: not sure + # template_above = template[max_sample_idx:, channels_above] + template_above = 
template[:, channels_above] channel_locations_above = channel_locations[channels_above] - peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 + peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) velocity_above, intercept, score = fit_velocity(peak_times_ms_above, distances_um_above) - # if DEBUG: - # fig, ax = plt.subplots() - # ax.plot(peak_times_ms_above, distances_um_above, "o") - # x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) - # ax.plot(x, intercept + x * velocity_above) - # ax.set_xlabel("Peak time (ms)") - # ax.set_ylabel("Distance from max channel (um)") - # ax.set_title(f"Velocity above: {velocity_above:.2f} um/ms") - - if np.sum(channels_above) < min_channels_for_velocity: - # if DEBUG: - # ax.set_title("NaN velocity - not enough channels") - return np.nan + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + offset = 1.2 * np.max(np.ptp(template, axis=0)) + ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + (channel_indices_above,) = np.nonzero(channels_above) + for i, single_template in enumerate(template.T): + color = "r" if i in channel_indices_above else "k" + axs[0].plot(ts, single_template + i * offset, color=color) + axs[0].axvline(0, color="g", ls="--") + axs[1].plot(peak_times_ms_above, distances_um_above, "o") + x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) + axs[1].plot(x, intercept + x * velocity_above) + axs[1].set_xlabel("Peak time (ms)") + axs[1].set_ylabel("Distance from max channel (um)") + fig.suptitle( + f"Velocity above: {velocity_above:.2f} um/ms - score {score:.2f} - channels: {np.sum(channels_above)}" + ) + plt.show() + + if np.sum(channels_above) < min_channels_for_velocity or score < min_r2_velocity: + 
velocity_above = np.nan - if score < min_r2_for_velocity: - # if DEBUG: - # ax.set_title(f"NaN velocity - R2 is too low: {score:.2f}") - return np.nan return velocity_above -def get_velocity_below(template, channel_locations, **kwargs): +def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs): """ Compute the velocity below the max channel of the template. @@ -670,55 +696,70 @@ def get_velocity_below(template, channel_locations, **kwargs): **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - - min_r2_for_velocity: the minimum r2 to accept the velocity fit - - sampling_frequency: the sampling frequency + - min_r2_velocity: the minimum r2 to accept the velocity fit + - same_x: whether to transform the template and channel locations to have the same x coordinate """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" - assert "min_r2_for_velocity" in kwargs, "min_r2_for_velocity must be given as kwarg" - direction = kwargs["depth_direction"] + assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" + assert "same_x" in kwargs, "same_x must be given as kwarg" + + depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] - min_r2_for_velocity = kwargs["min_r2_for_velocity"] - direction_index = ["x", "y", "z"].index(direction) + min_r2_velocity = kwargs["min_r2_velocity"] + same_x = kwargs["same_x"] + + direction_index = ["x", "y", "z"].index(depth_direction) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + if same_x: + template, channel_locations = transform_same_x(template, channel_locations) # find location of max channel max_sample_idx, 
max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] - sampling_frequency = kwargs["sampling_frequency"] channels_below = channel_locations[:, direction_index] <= max_channel_location[direction_index] # we only consider samples forward in time with respect to the max channel - template_below = template[max_sample_idx:, channels_below] + # template_below = template[max_sample_idx:, channels_below] + template_below = template[:, channels_below] channel_locations_below = channel_locations[channels_below] - peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 + peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) velocity_below, intercept, score = fit_velocity(peak_times_ms_below, distances_um_below) - # if DEBUG: - # fig, ax = plt.subplots() - # ax.plot(peak_times_ms_below, distances_um_below, "o") - # x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) - # ax.plot(x, intercept + x * velocity_below) - # ax.set_xlabel("Peak time (ms)") - # ax.set_ylabel("Distance from max channel (um)") - # ax.set_title(f"Velocity below: {np.round(velocity_below, 3)} um/ms") - - if np.sum(channels_below) < min_channels_for_velocity: - # if DEBUG: - # ax.set_title("NaN velocity - not enough channels") - return np.nan + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + offset = 1.2 * np.max(np.ptp(template, axis=0)) + ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + (channel_indices_below,) = np.nonzero(channels_below) + for i, single_template in enumerate(template.T): + color = "r" if i in channel_indices_below else "k" + axs[0].plot(ts, single_template + i * 
offset, color=color) + axs[0].axvline(0, color="g", ls="--") + axs[1].plot(peak_times_ms_below, distances_um_below, "o") + x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) + axs[1].plot(x, intercept + x * velocity_below) + axs[1].set_xlabel("Peak time (ms)") + axs[1].set_ylabel("Distance from max channel (um)") + fig.suptitle( + f"Velocity below: {np.round(velocity_below, 3)} um/ms - score {score:.2f} - channels: {np.sum(channels_below)}" + ) + plt.show() - if score < min_r2_for_velocity: - # if DEBUG: - # ax.set_title(f"NaN velocity - R2 is too low: {np.round(score, 3)}") - return np.nan + if np.sum(channels_below) < min_channels_for_velocity or score < min_r2_velocity: + velocity_below = np.nan return velocity_below -def get_exp_decay(template, channel_locations, **kwargs): +def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): """ Compute the exponential decay of the template amplitude over distance. @@ -730,14 +771,18 @@ def get_exp_decay(template, channel_locations, **kwargs): The channel locations (num_channels, 2) **kwargs: Required kwargs: - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") + - min_r2_exp_decay: the minimum r2 to accept the exp decay fit """ from scipy.optimize import curve_fit + from sklearn.metrics import r2_score - def exp_decay(x, a, b, c): - return a * np.exp(-b * x) + c + def exp_decay(x, decay, amp0, offset): + return amp0 * np.exp(-decay * x) + offset assert "exp_peak_function" in kwargs, "exp_peak_function must be given as kwarg" exp_peak_function = kwargs["exp_peak_function"] + assert "min_r2_exp_decay" in kwargs, "min_r2_exp_decay must be given as kwarg" + min_r2_exp_decay = kwargs["min_r2_exp_decay"] # exp decay fit if exp_peak_function == "ptp": fun = np.ptp @@ -747,25 +792,49 @@ def exp_decay(x, a, b, c): max_channel_location = channel_locations[np.argmax(peak_amplitudes)] channel_distances = np.array([np.linalg.norm(cl 
- max_channel_location) for cl in channel_locations]) distances_sort_indices = np.argsort(channel_distances) - channel_distances_sorted = channel_distances[distances_sort_indices] - peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices] + # np.float128 avoids overflow error + channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.float128) + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.float128) try: - popt, _ = curve_fit(exp_decay, channel_distances_sorted, peak_amplitudes_sorted) - exp_decay_value = popt[1] - # if DEBUG: - # fig, ax = plt.subplots() - # ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") - # x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) - # ax.plot(x, exp_decay(x, *popt)) - # ax.set_xlabel("Distance from max channel (um)") - # ax.set_ylabel("Peak amplitude") - # ax.set_title(f"Exp decay: {np.round(exp_decay_value, 3)}") + amp0 = peak_amplitudes_sorted[0] + offset0 = np.min(peak_amplitudes_sorted) + + popt, _ = curve_fit( + exp_decay, + channel_distances_sorted, + peak_amplitudes_sorted, + bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), + p0=[1e-3, peak_amplitudes_sorted[0], offset0], + ) + r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) + exp_decay_value = popt[0] + + if r2 < min_r2_exp_decay: + exp_decay_value = np.nan + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") + x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) + ax.plot(x, exp_decay(x, *popt)) + ax.set_xlabel("Distance from max channel (um)") + ax.set_ylabel("Peak amplitude") + ax.set_title( + f"Exp decay: {np.round(exp_decay_value, 3)} - Amp: {np.round(popt[1], 3)} - Offset: {np.round(popt[2], 3)} - " + f"R2: {np.round(r2, 4)}" + ) + fig.suptitle("Exp decay") + plt.show() except: 
exp_decay_value = np.nan + return exp_decay_value -def get_spread(template, channel_locations, **kwargs): +def get_spread(template, channel_locations, sampling_frequency, **kwargs): """ Compute the spread of the template amplitude over distance. @@ -783,23 +852,49 @@ def get_spread(template, channel_locations, **kwargs): depth_direction = kwargs["depth_direction"] assert "spread_threshold" in kwargs, "spread_threshold must be given as kwarg" spread_threshold = kwargs["spread_threshold"] + assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg" + spread_smooth_um = kwargs["spread_smooth_um"] + assert "same_x" in kwargs, "same_x must be given as kwarg" + same_x = kwargs["same_x"] direction_index = ["x", "y", "z"].index(depth_direction) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + if same_x: + template, channel_locations = transform_same_x(template, channel_locations) MM = np.ptp(template, 0) MM = MM / np.max(MM) + channel_depths = channel_locations[:, direction_index] + + if spread_smooth_um is not None and spread_smooth_um > 0: + from scipy.ndimage import gaussian_filter1d + + spread_sigma = spread_smooth_um / np.median(np.diff(np.unique(channel_depths))) + MM = gaussian_filter1d(MM, spread_sigma) + channel_locations_above_theshold = channel_locations[MM > spread_threshold] channel_depth_above_theshold = channel_locations_above_theshold[:, direction_index] spread = np.ptp(channel_depth_above_theshold) - # if DEBUG: - # fig, ax = plt.subplots() - # channel_depths = channel_locations[:, direction_index] - # sort_indices = np.argsort(channel_depths) - # ax.plot(channel_depths[sort_indices], MM[sort_indices], "o-") - # ax.axhline(spread_threshold, ls="--", color="r") - # ax.set_xlabel("Depth (um)") - # ax.set_ylabel("Amplitude") - # ax.set_title(f"Spread: {np.round(spread, 3)} um") + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, 
figsize=(10, 7)) + axs[0].imshow( + template.T, + aspect="auto", + origin="lower", + extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[1]], + ) + axs[1].plot(channel_depths, MM, "o-") + axs[1].axhline(spread_threshold, ls="--", color="r") + axs[1].set_xlabel("Depth (um)") + axs[1].set_ylabel("Amplitude") + axs[1].set_title(f"Spread: {np.round(spread, 3)} um") + fig.suptitle("Spread") + plt.show() + return spread From c1cd889beacca66f43262f95e18033100f98d59d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 29 Sep 2023 13:19:35 +0200 Subject: [PATCH 5/7] Add 'column_range' and simplify dimension handling --- .../postprocessing/template_metrics.py | 76 +++++++++++-------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 090dae4567..774ebab4a9 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -207,7 +207,7 @@ def get_extension_function(): min_r2_exp_decay=0.5, spread_threshold=0.2, spread_smooth_um=20, - same_x=False, + column_range=None, ) @@ -265,7 +265,13 @@ def compute_template_metrics( * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" + * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 * spread_threshold: the threshold to compute the spread, default: 0.2 + * spread_smooth_um: the smoothing in um to compute the spread, default: 20 + * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None + - If None, all channels are considered + - If 0 or 1, only the "column" that includes the max channel is
considered + - If > 1, only channels within range (+/-) um from the max channel horizontal position are used Returns ------- @@ -278,6 +284,7 @@ def compute_template_metrics( ----- If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, so that one metric value will be computed per unit. + For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". """ if debug_plots: global DEBUG @@ -294,6 +301,9 @@ def compute_template_metrics( "If multi-channel metrics are computed, sparsity must be None, " "so that each unit will correspond to 1 row of the output dataframe." ) + assert ( + waveform_extractor.get_channel_locations().shape[1] == 2 + ), "If multi-channel metrics are computed, channel locations must be 2D." default_kwargs = _default_function_kwargs.copy() if metrics_kwargs is None: metrics_kwargs = default_kwargs @@ -579,17 +589,22 @@ def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): # Multi-channel metrics -def transform_same_x(template, channel_locations): - max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] - same_x_mask = channel_locations[:, 0] == max_channel_x - channel_locations_same_x = channel_locations[same_x_mask] - template_same_x = template[:, same_x_mask] - return template_same_x, channel_locations_same_x +def transform_column_range(template, channel_locations, column_range, depth_direction="y"): + column_dim = 0 if depth_direction == "y" else 1 + if column_range is None: + template_column_range = template + channel_locations_column_range = channel_locations + else: + max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] + column_mask = np.abs(channel_locations[:, column_dim] - max_channel_x) <= column_range + template_column_range = template[:, column_mask] + channel_locations_column_range = channel_locations[column_mask] + return template_column_range,
channel_locations_column_range def sort_template_and_locations(template, channel_locations, depth_direction="y"): - direction_index = ["x", "y", "z"].index(depth_direction) - sort_indices = np.argsort(channel_locations[:, direction_index]) + depth_dim = 1 if depth_direction == "y" else 0 + sort_indices = np.argsort(channel_locations[:, depth_dim]) return template[:, sort_indices], channel_locations[sort_indices, :] @@ -621,29 +636,28 @@ def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - min_r2_velocity: the minimum r2 to accept the velocity fit + - column_range: the range in um in the x-direction to consider channels for velocity """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" - assert "same_x" in kwargs, "same_x must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] min_r2_velocity = kwargs["min_r2_velocity"] - same_x = kwargs["same_x"] + column_range = kwargs["column_range"] - direction_index = ["x", "y", "z"].index(depth_direction) + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range, depth_direction) template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - if same_x: - template, channel_locations = transform_same_x(template, channel_locations) - # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), 
template.shape) max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] - channels_above = channel_locations[:, direction_index] >= max_channel_location[direction_index] + channels_above = channel_locations[:, depth_dim] >= max_channel_location[depth_dim] # we only consider samples forward in time with respect to the max channel # TODO: not sure @@ -697,30 +711,28 @@ def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - min_r2_velocity: the minimum r2 to accept the velocity fit - - same_x: whether to transform the template and channel locations to have the same x coordinate + - column_range: the range in um in the x-direction to consider channels for velocity """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" - assert "same_x" in kwargs, "same_x must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] min_r2_velocity = kwargs["min_r2_velocity"] - same_x = kwargs["same_x"] + column_range = kwargs["column_range"] - direction_index = ["x", "y", "z"].index(depth_direction) + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - if same_x: - template, channel_locations = transform_same_x(template, channel_locations) - # find location of max channel max_sample_idx, 
max_channel_idx = np.unravel_index(np.argmin(template), template.shape) max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] - channels_below = channel_locations[:, direction_index] <= max_channel_location[direction_index] + channels_below = channel_locations[:, depth_dim] <= max_channel_location[depth_dim] # we only consider samples forward in time with respect to the max channel # template_below = template[max_sample_idx:, channels_below] @@ -847,6 +859,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - spread_threshold: the threshold to compute the spread + - column_range: the range in um in the x-direction to consider channels for velocity """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" depth_direction = kwargs["depth_direction"] @@ -854,17 +867,16 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): spread_threshold = kwargs["spread_threshold"] assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg" spread_smooth_um = kwargs["spread_smooth_um"] - assert "same_x" in kwargs, "same_x must be given as kwarg" - same_x = kwargs["same_x"] + assert "column_range" in kwargs, "column_range must be given as kwarg" + column_range = kwargs["column_range"] - direction_index = ["x", "y", "z"].index(depth_direction) + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - if same_x: - template, channel_locations = transform_same_x(template, channel_locations) MM = np.ptp(template, 0) MM = MM / np.max(MM) - channel_depths = channel_locations[:, direction_index] + channel_depths = 
channel_locations[:, depth_dim] if spread_smooth_um is not None and spread_smooth_um > 0: from scipy.ndimage import gaussian_filter1d @@ -873,7 +885,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): MM = gaussian_filter1d(MM, spread_sigma) channel_locations_above_theshold = channel_locations[MM > spread_threshold] - channel_depth_above_theshold = channel_locations_above_theshold[:, direction_index] + channel_depth_above_theshold = channel_locations_above_theshold[:, depth_dim] spread = np.ptp(channel_depth_above_theshold) global DEBUG @@ -885,7 +897,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): template.T, aspect="auto", origin="lower", - extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[1]], + extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]], ) axs[1].plot(channel_depths, MM, "o-") axs[1].axhline(spread_threshold, ls="--", color="r") From ac84b25530b04e30c80eba7c474be61279a7dd1f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 1 Oct 2023 15:11:30 +0200 Subject: [PATCH 6/7] Fix docstrings --- .../postprocessing/template_metrics.py | 67 ++++++++++++++----- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 774ebab4a9..82f55483b4 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -4,12 +4,13 @@ 22/04/2020 """ import numpy as np +import warnings +from typing import Optional from copy import deepcopy -from ..core import WaveformExtractor +from ..core import WaveformExtractor, ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.waveform_extractor import BaseWaveformExtractorExtension -import warnings global DEBUG @@ -211,16 +212,17 @@ def get_extension_function(): ) +# TODO: add 
typing def compute_template_metrics( waveform_extractor, - load_if_exists=False, - metric_names=None, - peak_sign="neg", - upsampling_factor=10, - sparsity=None, - include_multi_channel_metrics=False, - metrics_kwargs=None, - debug_plots=False, + load_if_exists: bool = False, + metric_names: Optional[list[str]] = None, + peak_sign: Optional[str] = "neg", + upsampling_factor: int = 10, + sparsity: Optional[ChannelSparsity] = None, + include_multi_channel_metrics: bool = False, + metrics_kwargs: dict = None, + debug_plots: bool = False, ): """ Compute template metrics including: @@ -247,13 +249,13 @@ def compute_template_metrics( metric_names : list, optional List of metrics to compute (see si.postprocessing.get_template_metric_names()), by default None peak_sign : {"neg", "pos"}, default: "neg" - The peak sign + Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. upsampling_factor : int, default: 10 The upsampling factor to upsample the templates - sparsity: dict or None, default: None - Default is sparsity=None and template metric is computed on extremum channel only. - If given, the dictionary should contain a unit ids as keys and a channel id or a list of channel ids as values. - For more generating a sparsity dict, see the postprocessing.compute_sparsity() function. + sparsity: ChannelSparsity or None, default: None + If None, template metrics are computed on the extremum channel only. + If sparsity is given, template metrics are computed on all sparse channels of each unit. + For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. 
include_multi_channel_metrics: bool, default: False Whether to compute multi-channel metrics metrics_kwargs: dict @@ -261,7 +263,7 @@ def compute_template_metrics( * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 * peak_width_ms: the width in samples to detect peaks, default: 0.2 - * depth_direction: the direction to compute velocity above and below, default: "y" + * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" @@ -284,7 +286,7 @@ def compute_template_metrics( ----- If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, so that one metric value will be computed per unit. - For multi-channel metrocs, 3D channel locations are not supported. By default, the depth direction is "y". + For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". 
""" if debug_plots: global DEBUG @@ -359,6 +361,8 @@ def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, pea ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template trough_idx: int, default: None The index of the trough peak_idx: int, default: None @@ -383,6 +387,8 @@ def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=N ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template trough_idx: int, default: None The index of the trough peak_idx: int, default: None @@ -407,6 +413,8 @@ def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_id ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template trough_idx: int, default: None The index of the trough peak_idx: int, default: None @@ -458,6 +466,8 @@ def get_repolarization_slope(template_single, sampling_frequency, trough_idx=Non ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template trough_idx: int, default: None The index of the trough """ @@ -499,6 +509,8 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template peak_idx: int, default: None The index of the peak **kwargs: Required kwargs: @@ -530,6 +542,8 @@ def get_num_positive_peaks(template_single, sampling_frequency, **kwargs): ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - peak_relative_threshold: the relative threshold to detect positive and negative peaks - peak_width_ms: the width in samples to 
detect peaks @@ -556,6 +570,8 @@ def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - peak_relative_threshold: the relative threshold to detect positive and negative peaks - peak_width_ms: the width in samples to detect peaks @@ -590,6 +606,9 @@ def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): def transform_column_range(template, channel_locations, column_range, depth_direction="y"): + """ + Transform template and channel locations based on column range. + """ column_dim = 0 if depth_direction == "y" else 1 if column_range is None: template_column_range = template @@ -603,12 +622,18 @@ def transform_column_range(template, channel_locations, column_range, depth_dire def sort_template_and_locations(template, channel_locations, depth_direction="y"): + """ + Sort template and locations. + """ depth_dim = 1 if depth_direction == "y" else 0 sort_indices = np.argsort(channel_locations[:, depth_dim]) return template[:, sort_indices], channel_locations[sort_indices, :] def fit_velocity(peak_times, channel_dist): + """ + Fit velocity from peak times and channel distances using robust Theil-Sen estimator.
+ """ # from scipy.stats import linregress # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) @@ -632,6 +657,8 @@ def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs The template waveform (num_samples, num_channels) channel_locations: numpy.ndarray The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity @@ -707,6 +734,8 @@ def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs The template waveform (num_samples, num_channels) channel_locations: numpy.ndarray The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity @@ -781,6 +810,8 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs The template waveform (num_samples, num_channels) channel_locations: numpy.ndarray The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") - min_r2_exp_decay: the minimum r2 to accept the exp decay fit @@ -856,6 +887,8 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): The template waveform (num_samples, num_channels) channel_locations: numpy.ndarray The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", 
"y", or "z") - spread_threshold: the threshold to compute the spread From 4e3140f58cec52b42563b02a5bfb2d0fdda498c3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 10:19:09 +0200 Subject: [PATCH 7/7] Remove comment --- src/spikeinterface/postprocessing/template_metrics.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 82f55483b4..3f47c505ad 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -41,7 +41,7 @@ class TemplateMetricsCalculator(BaseWaveformExtractorExtension): extension_name = "template_metrics" min_channels_for_multi_channel_warning = 10 - def __init__(self, waveform_extractor): + def __init__(self, waveform_extractor: WaveformExtractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) def _set_params( @@ -212,7 +212,6 @@ def get_extension_function(): ) -# TODO: add typing def compute_template_metrics( waveform_extractor, load_if_exists: bool = False,