Skip to content

Commit

Permalink
Add 'column_range' and simplify dimension handling
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 29, 2023
1 parent 7ba84ad commit c1cd889
Showing 1 changed file with 44 additions and 32 deletions.
76 changes: 44 additions & 32 deletions src/spikeinterface/postprocessing/template_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def get_extension_function():
min_r2_exp_decay=0.5,
spread_threshold=0.2,
spread_smooth_um=20,
same_x=False,
column_range=None,
)


Expand Down Expand Up @@ -265,7 +265,13 @@ def compute_template_metrics(
* min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5
* min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7
* exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp"
* min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5
* spread_threshold: the threshold to compute the spread, default: 0.2
* spread_smooth_um: the smoothing in um to compute the spread, default: 20
* column_range: the range in um in the horizontal direction to consider channels for velocity, default: None
- If None, all channels all channels are considered
- If 0 or 1, only the "column" that includes the max channel is considered
- If > 1, only channels within range (+/-) um from the max channel horizontal position are used
Returns
-------
Expand All @@ -278,6 +284,7 @@ def compute_template_metrics(
-----
If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None,
so that one metric value will be computed per unit.
For multi-channel metrocs, 3D channel locations are not supported. By default, the depth direction is "y".
"""
if debug_plots:
global DEBUG
Expand All @@ -294,6 +301,9 @@ def compute_template_metrics(
"If multi-channel metrics are computed, sparsity must be None, "
"so that each unit will correspond to 1 row of the output dataframe."
)
assert (
waveform_extractor.get_channel_locations().shape[1] == 2
), "If multi-channel metrics are computed, channel locations must be 2D."
default_kwargs = _default_function_kwargs.copy()
if metrics_kwargs is None:
metrics_kwargs = default_kwargs
Expand Down Expand Up @@ -579,17 +589,22 @@ def get_num_negative_peaks(template_single, sampling_frequency, **kwargs):
# Multi-channel metrics


def transform_same_x(template, channel_locations):
max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0]
same_x_mask = channel_locations[:, 0] == max_channel_x
channel_locations_same_x = channel_locations[same_x_mask]
template_same_x = template[:, same_x_mask]
return template_same_x, channel_locations_same_x
def transform_column_range(template, channel_locations, column_range, depth_direction="y"):
column_dim = 0 if depth_direction == "y" else 1
if column_range is None:
template_column_range = template
channel_locations_column_range = channel_locations
else:
max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0]
column_mask = np.abs(channel_locations[:, column_dim] - max_channel_x) <= column_range
template_column_range = template[:, column_mask]
channel_locations_column_range = channel_locations[column_mask]
return template_column_range, channel_locations_column_range


def sort_template_and_locations(template, channel_locations, depth_direction="y"):
direction_index = ["x", "y", "z"].index(depth_direction)
sort_indices = np.argsort(channel_locations[:, direction_index])
depth_dim = 1 if depth_direction == "y" else 0
sort_indices = np.argsort(channel_locations[:, depth_dim])
return template[:, sort_indices], channel_locations[sort_indices, :]


Expand Down Expand Up @@ -621,29 +636,28 @@ def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs
- depth_direction: the direction to compute velocity above and below ("x", "y", or "z")
- min_channels_for_velocity: the minimum number of channels above or below to compute velocity
- min_r2_velocity: the minimum r2 to accept the velocity fit
- column_range: the range in um in the x-direction to consider channels for velocity
"""
assert "depth_direction" in kwargs, "depth_direction must be given as kwarg"
assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg"
assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg"
assert "same_x" in kwargs, "same_x must be given as kwarg"
assert "column_range" in kwargs, "column_range must be given as kwarg"

depth_direction = kwargs["depth_direction"]
min_channels_for_velocity = kwargs["min_channels_for_velocity"]
min_r2_velocity = kwargs["min_r2_velocity"]
same_x = kwargs["same_x"]
column_range = kwargs["column_range"]

direction_index = ["x", "y", "z"].index(depth_direction)
depth_dim = 1 if depth_direction == "y" else 0
template, channel_locations = transform_column_range(template, channel_locations, column_range, depth_direction)
template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction)

if same_x:
template, channel_locations = transform_same_x(template, channel_locations)

# find location of max channel
max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape)
max_peak_time = max_sample_idx / sampling_frequency * 1000
max_channel_location = channel_locations[max_channel_idx]

channels_above = channel_locations[:, direction_index] >= max_channel_location[direction_index]
channels_above = channel_locations[:, depth_dim] >= max_channel_location[depth_dim]

# we only consider samples forward in time with respect to the max channel
# TODO: not sure
Expand Down Expand Up @@ -697,30 +711,28 @@ def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs
- depth_direction: the direction to compute velocity above and below ("x", "y", or "z")
- min_channels_for_velocity: the minimum number of channels above or below to compute velocity
- min_r2_velocity: the minimum r2 to accept the velocity fit
- same_x: whether to transform the template and channel locations to have the same x coordinate
- column_range: the range in um in the x-direction to consider channels for velocity
"""
assert "depth_direction" in kwargs, "depth_direction must be given as kwarg"
assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg"
assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg"
assert "same_x" in kwargs, "same_x must be given as kwarg"
assert "column_range" in kwargs, "column_range must be given as kwarg"

depth_direction = kwargs["depth_direction"]
min_channels_for_velocity = kwargs["min_channels_for_velocity"]
min_r2_velocity = kwargs["min_r2_velocity"]
same_x = kwargs["same_x"]
column_range = kwargs["column_range"]

direction_index = ["x", "y", "z"].index(depth_direction)
depth_dim = 1 if depth_direction == "y" else 0
template, channel_locations = transform_column_range(template, channel_locations, column_range)
template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction)

if same_x:
template, channel_locations = transform_same_x(template, channel_locations)

# find location of max channel
max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape)
max_peak_time = max_sample_idx / sampling_frequency * 1000
max_channel_location = channel_locations[max_channel_idx]

channels_below = channel_locations[:, direction_index] <= max_channel_location[direction_index]
channels_below = channel_locations[:, depth_dim] <= max_channel_location[depth_dim]

# we only consider samples forward in time with respect to the max channel
# template_below = template[max_sample_idx:, channels_below]
Expand Down Expand Up @@ -847,24 +859,24 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs):
**kwargs: Required kwargs:
- depth_direction: the direction to compute velocity above and below ("x", "y", or "z")
- spread_threshold: the threshold to compute the spread
- column_range: the range in um in the x-direction to consider channels for velocity
"""
assert "depth_direction" in kwargs, "depth_direction must be given as kwarg"
depth_direction = kwargs["depth_direction"]
assert "spread_threshold" in kwargs, "spread_threshold must be given as kwarg"
spread_threshold = kwargs["spread_threshold"]
assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg"
spread_smooth_um = kwargs["spread_smooth_um"]
assert "same_x" in kwargs, "same_x must be given as kwarg"
same_x = kwargs["same_x"]
assert "column_range" in kwargs, "column_range must be given as kwarg"
column_range = kwargs["column_range"]

direction_index = ["x", "y", "z"].index(depth_direction)
depth_dim = 1 if depth_direction == "y" else 0
template, channel_locations = transform_column_range(template, channel_locations, column_range)
template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction)

if same_x:
template, channel_locations = transform_same_x(template, channel_locations)
MM = np.ptp(template, 0)
MM = MM / np.max(MM)
channel_depths = channel_locations[:, direction_index]
channel_depths = channel_locations[:, depth_dim]

if spread_smooth_um is not None and spread_smooth_um > 0:
from scipy.ndimage import gaussian_filter1d
Expand All @@ -873,7 +885,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs):
MM = gaussian_filter1d(MM, spread_sigma)

channel_locations_above_theshold = channel_locations[MM > spread_threshold]
channel_depth_above_theshold = channel_locations_above_theshold[:, direction_index]
channel_depth_above_theshold = channel_locations_above_theshold[:, depth_dim]
spread = np.ptp(channel_depth_above_theshold)

global DEBUG
Expand All @@ -885,7 +897,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs):
template.T,
aspect="auto",
origin="lower",
extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[1]],
extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]],
)
axs[1].plot(channel_depths, MM, "o-")
axs[1].axhline(spread_threshold, ls="--", color="r")
Expand Down

0 comments on commit c1cd889

Please sign in to comment.