diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 9f85603e51..79456a40ce 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -811,14 +811,30 @@ def select_units(self, unit_ids, new_folder=None, use_relative_path: bool = Fals sparsity = ChannelSparsity(mask, unit_ids, self.channel_ids) else: sparsity = None - we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) - we.set_params(**self._params) + if self.has_recording(): + we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) + else: + we = WaveformExtractor( + recording=None, + sorting=sorting, + folder=None, + sparsity=sparsity, + rec_attributes=self._rec_attributes, + allow_unfiltered=True, + ) + we._params = self._params # copy memory objects if self.has_waveforms(): we._memory_objects = {"wfs_arrays": {}, "sampled_indices": {}} for unit_id in unit_ids: - we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] - we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][unit_id] + if self.format == "memory": + we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] + we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][ + unit_id + ] + else: + we._memory_objects["wfs_arrays"][unit_id] = self.get_waveforms(unit_id) + we._memory_objects["sampled_indices"][unit_id] = self.get_sampled_indices(unit_id) # finally select extensions data for ext_name in self.get_available_extension_names(): diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index ea44dea9cb..090dae4567 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -11,12 +11,9 @@ from ..core.waveform_extractor import BaseWaveformExtractorExtension import warnings -# DEBUG = True -# if DEBUG: -# import matplotlib.pyplot as plt -# plt.ion() -# plt.show() +global DEBUG +DEBUG = False def get_single_channel_template_metric_names(): @@ -52,20 +49,20 @@ def _set_params( peak_sign="neg", upsampling_factor=10, sparsity=None, - functions_kwargs=None, + metrics_kwargs=None, include_multi_channel_metrics=False, ): if metric_names is None: metric_names = get_single_channel_template_metric_names() if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - functions_kwargs = functions_kwargs or dict() + metrics_kwargs = metrics_kwargs or dict() params = dict( metric_names=[str(name) for name in metric_names], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - functions_kwargs=functions_kwargs, + metrics_kwargs=metrics_kwargs, ) return params @@ -141,7 +138,7 @@ def _run(self): sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self._params["functions_kwargs"], + **self._params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value @@ -173,7 +170,7 @@ def _run(self): template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self._params["functions_kwargs"], + **self._params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value self._extension_data["metrics"] = template_metrics @@ -199,6 +196,21 @@ def get_extension_function(): WaveformExtractor.register_extension(TemplateMetricsCalculator) +_default_function_kwargs = dict( + recovery_window_ms=0.7, + peak_relative_threshold=0.2, + peak_width_ms=0.1, + depth_direction="y", + min_channels_for_velocity=5, + min_r2_velocity=0.5, + exp_peak_function="ptp", + min_r2_exp_decay=0.5, + spread_threshold=0.2, + spread_smooth_um=20, + same_x=False, +) + + def compute_template_metrics( waveform_extractor, load_if_exists=False, @@ -207,16 +219,8 @@ def compute_template_metrics( upsampling_factor=10, sparsity=None, include_multi_channel_metrics=False, - functions_kwargs=dict( - recovery_window_ms=0.7, - peak_relative_threshold=0.2, - peak_width_ms=0.2, - depth_direction="y", - min_channels_for_velocity=5, - min_r2_for_velocity=0.5, - exp_peak_function="ptp", - spread_threshold=0.2, - ), + metrics_kwargs=None, + debug_plots=False, ): """ Compute template metrics including: @@ -252,14 +256,14 @@ def compute_template_metrics( For more generating a sparsity dict, see the postprocessing.compute_sparsity() function. include_multi_channel_metrics: bool, default: False Whether to compute multi-channel metrics - functions_kwargs: dict + metrics_kwargs: dict Additional arguments to pass to the metric functions. Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 * peak_width_ms: the width in samples to detect peaks, default: 0.2 * depth_direction: the direction to compute velocity above and below, default: "y" * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 - * min_r2_for_velocity: the minimum r2 to accept the velocity fit, default: 0.7 + * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" * spread_threshold: the threshold to compute the spread, default: 0.2 @@ -275,6 +279,9 @@ def compute_template_metrics( If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, so that one metric value will be computed per unit. """ + if debug_plots: + global DEBUG + DEBUG = True if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name): tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name) else: @@ -287,13 +294,19 @@ def compute_template_metrics( "If multi-channel metrics are computed, sparsity must be None, " "so that each unit will correspond to 1 row of the output dataframe." ) + default_kwargs = _default_function_kwargs.copy() + if metrics_kwargs is None: + metrics_kwargs = default_kwargs + else: + default_kwargs.update(metrics_kwargs) + metrics_kwargs = default_kwargs tmc.set_params( metric_names=metric_names, peak_sign=peak_sign, upsampling_factor=upsampling_factor, sparsity=sparsity, include_multi_channel_metrics=include_multi_channel_metrics, - functions_kwargs=functions_kwargs, + metrics_kwargs=metrics_kwargs, ) tmc.run() @@ -328,7 +341,7 @@ def get_trough_and_peak_idx(template): ######################################################################################### # Single-channel metrics -def get_peak_to_valley(template_single, trough_idx=None, peak_idx=None, **kwargs): +def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): """ Return the peak to valley duration in seconds of input waveforms. @@ -340,22 +353,19 @@ def get_peak_to_valley(template_single, trough_idx=None, peak_idx=None, **kwargs The index of the trough peak_idx: int, default: None The index of the peak - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency Returns ------- ptv: float The peak to valley duration in seconds """ - sampling_frequency = kwargs["sampling_frequency"] if trough_idx is None or peak_idx is None: trough_idx, peak_idx = get_trough_and_peak_idx(template_single) ptv = (peak_idx - trough_idx) / sampling_frequency return ptv -def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwargs): +def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs): """ Return the peak to trough ratio of input waveforms. @@ -367,8 +377,6 @@ def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwa The index of the trough peak_idx: int, default: None The index of the peak - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency Returns ------- @@ -381,7 +389,7 @@ def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwa return ptratio -def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): +def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): """ Return the half width of input waveforms in seconds. @@ -393,8 +401,6 @@ def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): The index of the trough peak_idx: int, default: None The index of the peak - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency Returns ------- @@ -403,7 +409,6 @@ def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): """ if trough_idx is None or peak_idx is None: trough_idx, peak_idx = get_trough_and_peak_idx(template_single) - sampling_frequency = kwargs["sampling_frequency"] if peak_idx == 0: return np.nan @@ -428,7 +433,7 @@ def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): return hw -def get_repolarization_slope(template_single, trough_idx=None, **kwargs): +def get_repolarization_slope(template_single, sampling_frequency, trough_idx=None, **kwargs): """ Return slope of repolarization period between trough and baseline @@ -445,12 +450,9 @@ def get_repolarization_slope(template_single, trough_idx=None, **kwargs): The 1D template waveform trough_idx: int, default: None The index of the trough - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency """ if trough_idx is None: trough_idx = get_trough_and_peak_idx(template_single) - sampling_frequency = kwargs["sampling_frequency"] times = np.arange(template_single.shape[0]) / sampling_frequency @@ -472,7 +474,7 @@ def get_repolarization_slope(template_single, trough_idx=None, **kwargs): return res.slope -def get_recovery_slope(template_single, peak_idx=None, **kwargs): +def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwargs): """ Return the recovery slope of input waveforms. After repolarization, the neuron hyperpolarizes untill it peaks. The recovery slope is the @@ -490,7 +492,6 @@ def get_recovery_slope(template_single, peak_idx=None, **kwargs): peak_idx: int, default: None The index of the peak **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency - recovery_window_ms: the window in ms after the peak to compute the recovery_slope """ import scipy.stats @@ -499,7 +500,6 @@ def get_recovery_slope(template_single, peak_idx=None, **kwargs): recovery_window_ms = kwargs["recovery_window_ms"] if peak_idx is None: _, peak_idx = get_trough_and_peak_idx(template_single) - sampling_frequency = kwargs["sampling_frequency"] times = np.arange(template_single.shape[0]) / sampling_frequency @@ -512,7 +512,7 @@ def get_recovery_slope(template_single, peak_idx=None, **kwargs): return res.slope -def get_num_positive_peaks(template_single, **kwargs): +def get_num_positive_peaks(template_single, sampling_frequency, **kwargs): """ Count the number of positive peaks in the template. @@ -523,7 +523,6 @@ def get_num_positive_peaks(template_single, **kwargs): **kwargs: Required kwargs: - peak_relative_threshold: the relative threshold to detect positive and negative peaks - peak_width_ms: the width in samples to detect peaks - - sampling_frequency: the sampling frequency """ from scipy.signal import find_peaks @@ -532,14 +531,14 @@ def get_num_positive_peaks(template_single, **kwargs): peak_relative_threshold = kwargs["peak_relative_threshold"] peak_width_ms = kwargs["peak_width_ms"] max_value = np.max(np.abs(template_single)) - peak_width_samples = int(peak_width_ms / 1000 * kwargs["sampling_frequency"]) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) return len(pos_peaks[0]) -def get_num_negative_peaks(template_single, **kwargs): +def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): """ Count the number of negative peaks in the template. @@ -550,7 +549,6 @@ def get_num_negative_peaks(template_single, **kwargs): **kwargs: Required kwargs: - peak_relative_threshold: the relative threshold to detect positive and negative peaks - peak_width_ms: the width in samples to detect peaks - - sampling_frequency: the sampling frequency """ from scipy.signal import find_peaks @@ -559,7 +557,7 @@ def get_num_negative_peaks(template_single, **kwargs): peak_relative_threshold = kwargs["peak_relative_threshold"] peak_width_ms = kwargs["peak_width_ms"] max_value = np.max(np.abs(template_single)) - peak_width_samples = int(peak_width_ms / 1000 * kwargs["sampling_frequency"]) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) @@ -581,6 +579,20 @@ def get_num_negative_peaks(template_single, **kwargs): # Multi-channel metrics +def transform_same_x(template, channel_locations): + max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] + same_x_mask = channel_locations[:, 0] == max_channel_x + channel_locations_same_x = channel_locations[same_x_mask] + template_same_x = template[:, same_x_mask] + return template_same_x, channel_locations_same_x + + +def sort_template_and_locations(template, channel_locations, depth_direction="y"): + direction_index = ["x", "y", "z"].index(depth_direction) + sort_indices = np.argsort(channel_locations[:, direction_index]) + return template[:, sort_indices], channel_locations[sort_indices, :] + + def fit_velocity(peak_times, channel_dist): # from scipy.stats import linregress # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) @@ -595,7 +607,7 @@ def fit_velocity(peak_times, channel_dist): return slope, intercept, score -def get_velocity_above(template, channel_locations, **kwargs): +def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs): """ Compute the velocity above the max channel of the template. @@ -608,56 +620,70 @@ def get_velocity_above(template, channel_locations, **kwargs): **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - - min_r2_for_velocity: the minimum r2 to accept the velocity fit - - sampling_frequency: the sampling frequency + - min_r2_velocity: the minimum r2 to accept the velocity fit """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" - assert "min_r2_for_velocity" in kwargs, "min_r2_for_velocity must be given as kwarg" + assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" + assert "same_x" in kwargs, "same_x must be given as kwarg" depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] - min_r2_for_velocity = kwargs["min_r2_for_velocity"] + min_r2_velocity = kwargs["min_r2_velocity"] + same_x = kwargs["same_x"] direction_index = ["x", "y", "z"].index(depth_direction) - sampling_frequency = kwargs["sampling_frequency"] + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + if same_x: + template, channel_locations = transform_same_x(template, channel_locations) # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] channels_above = channel_locations[:, direction_index] >= max_channel_location[direction_index] # we only consider samples forward in time with respect to the max channel - template_above = template[max_sample_idx:, channels_above] + # TODO: not sure + # template_above = template[max_sample_idx:, channels_above] + template_above = template[:, channels_above] channel_locations_above = channel_locations[channels_above] - peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 + peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) velocity_above, intercept, score = fit_velocity(peak_times_ms_above, distances_um_above) - # if DEBUG: - # fig, ax = plt.subplots() - # ax.plot(peak_times_ms_above, distances_um_above, "o") - # x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) - # ax.plot(x, intercept + x * velocity_above) - # ax.set_xlabel("Peak time (ms)") - # ax.set_ylabel("Distance from max channel (um)") - # ax.set_title(f"Velocity above: {velocity_above:.2f} um/ms") - - if np.sum(channels_above) < min_channels_for_velocity: - # if DEBUG: - # ax.set_title("NaN velocity - not enough channels") - return np.nan + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + offset = 1.2 * np.max(np.ptp(template, axis=0)) + ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + (channel_indices_above,) = np.nonzero(channels_above) + for i, single_template in enumerate(template.T): + color = "r" if i in channel_indices_above else "k" + axs[0].plot(ts, single_template + i * offset, color=color) + axs[0].axvline(0, color="g", ls="--") + axs[1].plot(peak_times_ms_above, distances_um_above, "o") + x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) + axs[1].plot(x, intercept + x * velocity_above) + axs[1].set_xlabel("Peak time (ms)") + axs[1].set_ylabel("Distance from max channel (um)") + fig.suptitle( + f"Velocity above: {velocity_above:.2f} um/ms - score {score:.2f} - channels: {np.sum(channels_above)}" + ) + plt.show() + + if np.sum(channels_above) < min_channels_for_velocity or score < min_r2_velocity: + velocity_above = np.nan - if score < min_r2_for_velocity: - # if DEBUG: - # ax.set_title(f"NaN velocity - R2 is too low: {score:.2f}") - return np.nan return velocity_above -def get_velocity_below(template, channel_locations, **kwargs): +def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs): """ Compute the velocity below the max channel of the template. @@ -670,55 +696,70 @@ def get_velocity_below(template, channel_locations, **kwargs): **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - - min_r2_for_velocity: the minimum r2 to accept the velocity fit - - sampling_frequency: the sampling frequency + - min_r2_velocity: the minimum r2 to accept the velocity fit + - same_x: whether to transform the template and channel locations to have the same x coordinate """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" - assert "min_r2_for_velocity" in kwargs, "min_r2_for_velocity must be given as kwarg" - direction = kwargs["depth_direction"] + assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" + assert "same_x" in kwargs, "same_x must be given as kwarg" + + depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] - min_r2_for_velocity = kwargs["min_r2_for_velocity"] - direction_index = ["x", "y", "z"].index(direction) + min_r2_velocity = kwargs["min_r2_velocity"] + same_x = kwargs["same_x"] + + direction_index = ["x", "y", "z"].index(depth_direction) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + if same_x: + template, channel_locations = transform_same_x(template, channel_locations) # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] - sampling_frequency = kwargs["sampling_frequency"] channels_below = channel_locations[:, direction_index] <= max_channel_location[direction_index] # we only consider samples forward in time with respect to the max channel - template_below = template[max_sample_idx:, channels_below] + # template_below = template[max_sample_idx:, channels_below] + template_below = template[:, channels_below] channel_locations_below = channel_locations[channels_below] - peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 + peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) velocity_below, intercept, score = fit_velocity(peak_times_ms_below, distances_um_below) - # if DEBUG: - # fig, ax = plt.subplots() - # ax.plot(peak_times_ms_below, distances_um_below, "o") - # x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) - # ax.plot(x, intercept + x * velocity_below) - # ax.set_xlabel("Peak time (ms)") - # ax.set_ylabel("Distance from max channel (um)") - # ax.set_title(f"Velocity below: {np.round(velocity_below, 3)} um/ms") - - if np.sum(channels_below) < min_channels_for_velocity: - # if DEBUG: - # ax.set_title("NaN velocity - not enough channels") - return np.nan + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + offset = 1.2 * np.max(np.ptp(template, axis=0)) + ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + (channel_indices_below,) = np.nonzero(channels_below) + for i, single_template in enumerate(template.T): + color = "r" if i in channel_indices_below else "k" + axs[0].plot(ts, single_template + i * offset, color=color) + axs[0].axvline(0, color="g", ls="--") + axs[1].plot(peak_times_ms_below, distances_um_below, "o") + x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) + axs[1].plot(x, intercept + x * velocity_below) + axs[1].set_xlabel("Peak time (ms)") + axs[1].set_ylabel("Distance from max channel (um)") + fig.suptitle( + f"Velocity below: {np.round(velocity_below, 3)} um/ms - score {score:.2f} - channels: {np.sum(channels_below)}" + ) + plt.show() - if score < min_r2_for_velocity: - # if DEBUG: - # ax.set_title(f"NaN velocity - R2 is too low: {np.round(score, 3)}") - return np.nan + if np.sum(channels_below) < min_channels_for_velocity or score < min_r2_velocity: + velocity_below = np.nan return velocity_below -def get_exp_decay(template, channel_locations, **kwargs): +def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): """ Compute the exponential decay of the template amplitude over distance. @@ -730,14 +771,18 @@ def get_exp_decay(template, channel_locations, **kwargs): The channel locations (num_channels, 2) **kwargs: Required kwargs: - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") + - min_r2_exp_decay: the minimum r2 to accept the exp decay fit """ from scipy.optimize import curve_fit + from sklearn.metrics import r2_score - def exp_decay(x, a, b, c): - return a * np.exp(-b * x) + c + def exp_decay(x, decay, amp0, offset): + return amp0 * np.exp(-decay * x) + offset assert "exp_peak_function" in kwargs, "exp_peak_function must be given as kwarg" exp_peak_function = kwargs["exp_peak_function"] + assert "min_r2_exp_decay" in kwargs, "min_r2_exp_decay must be given as kwarg" + min_r2_exp_decay = kwargs["min_r2_exp_decay"] # exp decay fit if exp_peak_function == "ptp": fun = np.ptp @@ -747,25 +792,49 @@ def exp_decay(x, a, b, c): max_channel_location = channel_locations[np.argmax(peak_amplitudes)] channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) distances_sort_indices = np.argsort(channel_distances) - channel_distances_sorted = channel_distances[distances_sort_indices] - peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices] + # np.float128 avoids overflow error + channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.float128) + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.float128) try: - popt, _ = curve_fit(exp_decay, channel_distances_sorted, peak_amplitudes_sorted) - exp_decay_value = popt[1] - # if DEBUG: - # fig, ax = plt.subplots() - # ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") - # x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) - # ax.plot(x, exp_decay(x, *popt)) - # ax.set_xlabel("Distance from max channel (um)") - # ax.set_ylabel("Peak amplitude") - # ax.set_title(f"Exp decay: {np.round(exp_decay_value, 3)}") + amp0 = peak_amplitudes_sorted[0] + offset0 = np.min(peak_amplitudes_sorted) + + popt, _ = curve_fit( + exp_decay, + channel_distances_sorted, + peak_amplitudes_sorted, + bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), + p0=[1e-3, peak_amplitudes_sorted[0], offset0], + ) + r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) + exp_decay_value = popt[0] + + if r2 < min_r2_exp_decay: + exp_decay_value = np.nan + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") + x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) + ax.plot(x, exp_decay(x, *popt)) + ax.set_xlabel("Distance from max channel (um)") + ax.set_ylabel("Peak amplitude") + ax.set_title( + f"Exp decay: {np.round(exp_decay_value, 3)} - Amp: {np.round(popt[1], 3)} - Offset: {np.round(popt[2], 3)} - " + f"R2: {np.round(r2, 4)}" + ) + fig.suptitle("Exp decay") + plt.show() except: exp_decay_value = np.nan + return exp_decay_value -def get_spread(template, channel_locations, **kwargs): +def get_spread(template, channel_locations, sampling_frequency, **kwargs): """ Compute the spread of the template amplitude over distance. @@ -783,23 +852,49 @@ def get_spread(template, channel_locations, **kwargs): depth_direction = kwargs["depth_direction"] assert "spread_threshold" in kwargs, "spread_threshold must be given as kwarg" spread_threshold = kwargs["spread_threshold"] + assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg" + spread_smooth_um = kwargs["spread_smooth_um"] + assert "same_x" in kwargs, "same_x must be given as kwarg" + same_x = kwargs["same_x"] direction_index = ["x", "y", "z"].index(depth_direction) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + if same_x: + template, channel_locations = transform_same_x(template, channel_locations) MM = np.ptp(template, 0) MM = MM / np.max(MM) + channel_depths = channel_locations[:, direction_index] + + if spread_smooth_um is not None and spread_smooth_um > 0: + from scipy.ndimage import gaussian_filter1d + + spread_sigma = spread_smooth_um / np.median(np.diff(np.unique(channel_depths))) + MM = gaussian_filter1d(MM, spread_sigma) + channel_locations_above_theshold = channel_locations[MM > spread_threshold] channel_depth_above_theshold = channel_locations_above_theshold[:, direction_index] spread = np.ptp(channel_depth_above_theshold) - # if DEBUG: - # fig, ax = plt.subplots() - # channel_depths = channel_locations[:, direction_index] - # sort_indices = np.argsort(channel_depths) - # ax.plot(channel_depths[sort_indices], MM[sort_indices], "o-") - # ax.axhline(spread_threshold, ls="--", color="r") - # ax.set_xlabel("Depth (um)") - # ax.set_ylabel("Amplitude") - # ax.set_title(f"Spread: {np.round(spread, 3)} um") + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + axs[0].imshow( + template.T, + aspect="auto", + origin="lower", + extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[1]], + ) + axs[1].plot(channel_depths, MM, "o-") + axs[1].axhline(spread_threshold, ls="--", color="r") + axs[1].set_xlabel("Depth (um)") + axs[1].set_ylabel("Amplitude") + axs[1].set_title(f"Spread: {np.round(spread, 3)} um") + fig.suptitle("Spread") + plt.show() + return spread