From c1cd889beacca66f43262f95e18033100f98d59d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 29 Sep 2023 13:19:35 +0200 Subject: [PATCH] Add 'column_range' and simplify dimension handling --- .../postprocessing/template_metrics.py | 76 +++++++++++-------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 090dae4567..774ebab4a9 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -207,7 +207,7 @@ def get_extension_function(): min_r2_exp_decay=0.5, spread_threshold=0.2, spread_smooth_um=20, - same_x=False, + column_range=None, ) @@ -265,7 +265,13 @@ def compute_template_metrics( * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" + * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 * spread_threshold: the threshold to compute the spread, default: 0.2 + * spread_smooth_um: the smoothing in um to compute the spread, default: 20 + * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None + - If None, all channels all channels are considered + - If 0 or 1, only the "column" that includes the max channel is considered + - If > 1, only channels within range (+/-) um from the max channel horizontal position are used Returns ------- @@ -278,6 +284,7 @@ def compute_template_metrics( ----- If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, so that one metric value will be computed per unit. + For multi-channel metrocs, 3D channel locations are not supported. By default, the depth direction is "y". """ if debug_plots: global DEBUG @@ -294,6 +301,9 @@ def compute_template_metrics( "If multi-channel metrics are computed, sparsity must be None, " "so that each unit will correspond to 1 row of the output dataframe." ) + assert ( + waveform_extractor.get_channel_locations().shape[1] == 2 + ), "If multi-channel metrics are computed, channel locations must be 2D." default_kwargs = _default_function_kwargs.copy() if metrics_kwargs is None: metrics_kwargs = default_kwargs @@ -579,17 +589,22 @@ def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): # Multi-channel metrics -def transform_same_x(template, channel_locations): - max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] - same_x_mask = channel_locations[:, 0] == max_channel_x - channel_locations_same_x = channel_locations[same_x_mask] - template_same_x = template[:, same_x_mask] - return template_same_x, channel_locations_same_x +def transform_column_range(template, channel_locations, column_range, depth_direction="y"): + column_dim = 0 if depth_direction == "y" else 1 + if column_range is None: + template_column_range = template + channel_locations_column_range = channel_locations + else: + max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] + column_mask = np.abs(channel_locations[:, column_dim] - max_channel_x) <= column_range + template_column_range = template[:, column_mask] + channel_locations_column_range = channel_locations[column_mask] + return template_column_range, channel_locations_column_range def sort_template_and_locations(template, channel_locations, depth_direction="y"): - direction_index = ["x", "y", "z"].index(depth_direction) - sort_indices = np.argsort(channel_locations[:, direction_index]) + depth_dim = 1 if depth_direction == "y" else 0 + sort_indices = np.argsort(channel_locations[:, depth_dim]) return template[:, sort_indices], channel_locations[sort_indices, :] @@ -621,29 +636,28 @@ def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - min_r2_velocity: the minimum r2 to accept the velocity fit + - column_range: the range in um in the x-direction to consider channels for velocity """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" - assert "same_x" in kwargs, "same_x must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] min_r2_velocity = kwargs["min_r2_velocity"] - same_x = kwargs["same_x"] + column_range = kwargs["column_range"] - direction_index = ["x", "y", "z"].index(depth_direction) + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range, depth_direction) template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - if same_x: - template, channel_locations = transform_same_x(template, channel_locations) - # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] - channels_above = channel_locations[:, direction_index] >= max_channel_location[direction_index] + channels_above = channel_locations[:, depth_dim] >= max_channel_location[depth_dim] # we only consider samples forward in time with respect to the max channel # TODO: not sure @@ -697,30 +711,28 @@ def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - min_r2_velocity: the minimum r2 to accept the velocity fit - - same_x: whether to transform the template and channel locations to have the same x coordinate + - column_range: the range in um in the x-direction to consider channels for velocity """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" - assert "same_x" in kwargs, "same_x must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] min_r2_velocity = kwargs["min_r2_velocity"] - same_x = kwargs["same_x"] + column_range = kwargs["column_range"] - direction_index = ["x", "y", "z"].index(depth_direction) + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - if same_x: - template, channel_locations = transform_same_x(template, channel_locations) - # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] - channels_below = channel_locations[:, direction_index] <= max_channel_location[direction_index] + channels_below = channel_locations[:, depth_dim] <= max_channel_location[depth_dim] # we only consider samples forward in time with respect to the max channel # template_below = template[max_sample_idx:, channels_below] @@ -847,6 +859,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - spread_threshold: the threshold to compute the spread + - column_range: the range in um in the x-direction to consider channels for velocity """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" depth_direction = kwargs["depth_direction"] @@ -854,17 +867,16 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): spread_threshold = kwargs["spread_threshold"] assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg" spread_smooth_um = kwargs["spread_smooth_um"] - assert "same_x" in kwargs, "same_x must be given as kwarg" - same_x = kwargs["same_x"] + assert "column_range" in kwargs, "column_range must be given as kwarg" + column_range = kwargs["column_range"] - direction_index = ["x", "y", "z"].index(depth_direction) + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - if same_x: - template, channel_locations = transform_same_x(template, channel_locations) MM = np.ptp(template, 0) MM = MM / np.max(MM) - channel_depths = channel_locations[:, direction_index] + channel_depths = channel_locations[:, depth_dim] if spread_smooth_um is not None and spread_smooth_um > 0: from scipy.ndimage import gaussian_filter1d @@ -873,7 +885,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): MM = gaussian_filter1d(MM, spread_sigma) channel_locations_above_theshold = channel_locations[MM > spread_threshold] - channel_depth_above_theshold = channel_locations_above_theshold[:, direction_index] + channel_depth_above_theshold = channel_locations_above_theshold[:, depth_dim] spread = np.ptp(channel_depth_above_theshold) global DEBUG @@ -885,7 +897,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): template.T, aspect="auto", origin="lower", - extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[1]], + extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]], ) axs[1].plot(channel_depths, MM, "o-") axs[1].axhline(spread_threshold, ls="--", color="r")