Enhancing curation: get_potential_auto_merge() #2753

Merged: 11 commits, Apr 30, 2024
86 changes: 67 additions & 19 deletions src/spikeinterface/curation/auto_merge.py
@@ -3,6 +3,7 @@
import numpy as np

from ..core import create_sorting_analyzer
+ from ..core.template import Templates
from ..core.template_tools import get_template_extremum_channel
from ..postprocessing import compute_correlograms
from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates
@@ -30,6 +31,7 @@ def get_potential_auto_merge(
firing_contamination_balance=1.5,
extra_outputs=False,
steps=None,
template_metric="l1",
):
"""
Algorithm to find and check potential merges between units.
@@ -63,7 +65,7 @@
Minimum number of spikes for each unit to consider a potential merge.
Enough spikes are needed to estimate the correlogram
maximum_distance_um: float, default: 150
- Minimum distance between units for considering a merge
+ Maximum distance between units for considering a merge
peak_sign: "neg" | "pos" | "both", default: "neg"
Peak sign used to estimate the maximum channel of a template
bin_ms: float, default: 0.25
@@ -101,6 +103,8 @@ def get_potential_auto_merge(
If None, all steps are done.
Potential steps: "min_spikes", "remove_contaminated", "unit_positions", "correlogram", "template_similarity",
"check_increase_score". Please check the steps explanations above!
+ template_metric: "l1" | "l2" | "cosine", default: "l1"
+     The metric used to compute the distance between templates

Returns
-------
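A minimal usage sketch of the new parameter (illustrative values; it assumes a `SortingAnalyzer` with the correlograms and templates extensions already computed, which is not shown in this diff):

```python
from spikeinterface.curation import get_potential_auto_merge

# sorting_analyzer: a SortingAnalyzer prepared beforehand (assumed setup)
merges = get_potential_auto_merge(
    sorting_analyzer,
    maximum_distance_um=150.0,
    template_diff_thresh=0.25,
    template_metric="l2",  # new in this PR: "l1" (default), "l2", or "cosine"
)
# merges is a list of unit-id pairs suggested for merging
print(merges)
```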
@@ -114,6 +118,7 @@
import scipy

sorting = sorting_analyzer.sorting
+ recording = sorting_analyzer.recording
unit_ids = sorting.unit_ids

# to get fast computation we will not analyse pairs when:
@@ -154,12 +159,17 @@

# STEP 3 : unit positions are estimated roughly from the channel locations
if "unit_positions" in steps:
- chan_loc = sorting_analyzer.get_channel_locations()
- unit_max_chan = get_template_extremum_channel(
-     sorting_analyzer, peak_sign=peak_sign, mode="extremum", outputs="index"
- )
- unit_max_chan = list(unit_max_chan.values())
- unit_locations = chan_loc[unit_max_chan, :]
+ positions_ext = sorting_analyzer.get_extension("unit_locations")
+ if positions_ext is not None:
+     unit_locations = positions_ext.get_data()[:, :2]
+ else:
+     chan_loc = sorting_analyzer.get_channel_locations()
+     unit_max_chan = get_template_extremum_channel(
+         sorting_analyzer, peak_sign=peak_sign, mode="extremum", outputs="index"
+     )
+     unit_max_chan = list(unit_max_chan.values())
+     unit_locations = chan_loc[unit_max_chan, :]

unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean")
pair_mask = pair_mask & (unit_distances <= maximum_distance_um)

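The distance gating above prunes candidate pairs before the more expensive correlogram and template checks. A standalone sketch of the same idea, with toy coordinates and an illustrative threshold (not code from this PR):

```python
import numpy as np
import scipy.spatial.distance

# Toy 2D unit locations in um; values are illustrative only
unit_locations = np.array([[0.0, 0.0], [40.0, 0.0], [500.0, 120.0]])
maximum_distance_um = 150.0

# Pairwise Euclidean distances between putative units
unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean")

# Keep only pairs close enough to plausibly come from the same neuron
pair_mask = unit_distances <= maximum_distance_um
print(pair_mask)  # units 0 and 1 remain candidates; unit 2 is too far
```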
@@ -194,10 +204,18 @@
templates_ext is not None
), "auto_merge with template_similarity requires a SortingAnalyzer with extension templates"

- templates = templates_ext.get_templates(operator="average")
+ templates_array = templates_ext.get_data(outputs="numpy")

templates_diff = compute_templates_diff(
- sorting, templates, num_channels=num_channels, num_shift=num_shift, pair_mask=pair_mask
+ sorting,
+ templates_array,
+ num_channels=num_channels,
+ num_shift=num_shift,
+ pair_mask=pair_mask,
+ template_metric=template_metric,
+ sparsity=sorting_analyzer.sparsity,
)

pair_mask = pair_mask & (templates_diff < template_diff_thresh)

# STEP 6 : validate the potential merges by checking that the firing/contamination score improves
@@ -378,23 +396,29 @@ def get_unit_adaptive_window(auto_corr: np.ndarray, threshold: float):
return win_size


- def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair_mask=None):
+ def compute_templates_diff(
+     sorting, templates_array, num_channels=5, num_shift=5, pair_mask=None, template_metric="l1", sparsity=None
+ ):
"""
- Computes normalilzed template differences.
+ Computes normalized template differences.

Parameters
----------
sorting : BaseSorting
The sorting object
- templates : np.array
-     The templates array (num_units, num_samples, num_channels)
+ templates_array : np.array
+     The templates array (num_units, num_samples, num_channels).
num_channels: int, default: 5
Number of channels to use for template similarity computation
num_shift: int, default: 5
Number of shifts in samples to be explored for template similarity computation
pair_mask: None or boolean array
A bool matrix of size (num_units, num_units) to select
which pair to compute.
+ template_metric: "l1" | "l2" | "cosine", default: "l1"
+     The metric used to compute the distance between templates
+ sparsity: None or ChannelSparsity, default: None
+     Optionally, a ChannelSparsity object. If given and num_channels is None,
+     channels are selected adaptively per pair from the units' sparse channels.

Returns
-------
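For the adaptive-channel behavior described above, the caller passes a `ChannelSparsity` and disables the fixed best-channel count. A hypothetical direct call illustrating that path (variable setup assumed, not shown in this diff):

```python
# With num_channels=None and a sparsity given, each pair is compared on the
# intersection of the two units' sparse channels (see the code below)
templates_diff = compute_templates_diff(
    sorting,
    templates_array,
    num_channels=None,
    num_shift=5,
    template_metric="cosine",
    sparsity=sorting_analyzer.sparsity,
)
```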
@@ -403,30 +427,54 @@ def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair
"""
unit_ids = sorting.unit_ids
n = len(unit_ids)
+ assert template_metric in ["l1", "l2", "cosine"], "Not a valid metric!"

if pair_mask is None:
pair_mask = np.ones((n, n), dtype="bool")

+ if sparsity is None:
+     adaptative_masks = False
+     sparsity_mask = None
+ else:
+     adaptative_masks = num_channels is None
+     sparsity_mask = sparsity.mask

templates_diff = np.full((n, n), np.nan, dtype="float64")
for unit_ind1 in range(n):
for unit_ind2 in range(unit_ind1 + 1, n):
if not pair_mask[unit_ind1, unit_ind2]:
continue

- template1 = templates[unit_ind1]
- template2 = templates[unit_ind2]
+ template1 = templates_array[unit_ind1]
+ template2 = templates_array[unit_ind2]
# take best channels
- chan_inds = np.argsort(np.max(np.abs(template1 + template2), axis=0))[::-1][:num_channels]
+ if not adaptative_masks:
+     chan_inds = np.argsort(np.max(np.abs(template1 + template2), axis=0))[::-1][:num_channels]
+ else:
+     chan_inds = np.intersect1d(
+         np.flatnonzero(sparsity_mask[unit_ind1]), np.flatnonzero(sparsity_mask[unit_ind2])
+     )

template1 = template1[:, chan_inds]
template2 = template2[:, chan_inds]

num_samples = template1.shape[0]
- norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2))
+ if template_metric == "l1":
+     norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2))
+ elif template_metric == "l2":
+     norm = np.sum(template1**2) + np.sum(template2**2)
+ elif template_metric == "cosine":
+     norm = np.linalg.norm(template1) * np.linalg.norm(template2)
all_shift_diff = []
for shift in range(-num_shift, num_shift + 1):
temp1 = template1[num_shift : num_samples - num_shift, :]
temp2 = template2[num_shift + shift : num_samples - num_shift + shift, :]
- d = np.sum(np.abs(temp1 - temp2)) / (norm)
+ if template_metric == "l1":
+     d = np.sum(np.abs(temp1 - temp2)) / norm
+ elif template_metric == "l2":
+     d = np.linalg.norm(temp1 - temp2) / norm
+ elif template_metric == "cosine":
+     d = 1 - np.sum(temp1 * temp2) / norm
all_shift_diff.append(d)
templates_diff[unit_ind1, unit_ind2] = np.min(all_shift_diff)

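To make the three normalizations concrete, here is a small self-contained check of the per-shift distance `d` for two toy templates, mirroring the expressions in the diff (toy data, not part of the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
temp1 = rng.normal(size=(90, 4))                 # (num_samples, num_channels), toy values
temp2 = temp1 + 0.1 * rng.normal(size=(90, 4))   # a slightly perturbed copy

# "l1": sum of absolute differences, normalized by the summed L1 energies
l1 = np.sum(np.abs(temp1 - temp2)) / (np.sum(np.abs(temp1)) + np.sum(np.abs(temp2)))

# "l2": Euclidean distance, normalized by the summed squared energies
l2 = np.linalg.norm(temp1 - temp2) / (np.sum(temp1**2) + np.sum(temp2**2))

# "cosine": 1 minus the cosine similarity of the flattened templates
cos = 1 - np.sum(temp1 * temp2) / (np.linalg.norm(temp1) * np.linalg.norm(temp2))

print(l1, l2, cos)  # all near 0 for nearly identical templates, larger otherwise
```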
@@ -437,7 +485,7 @@ def check_improve_contaminations_score(
sorting_analyzer, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms
):
"""
- Check that the score is improve afeter a potential merge
+ Check that the score improves after a potential merge

The score is a balance between:
* contamination decrease
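The rest of this docstring is truncated in the diff. As a hedged sketch only (the exact expression is an assumption, not quoted from this PR), the balance can be thought of as a score combining firing rate `f` and contamination `c`, with a merge accepted only when the merged unit's score beats both originals:

```python
def merge_score(firing_rate, contamination, k=1.5):
    # Hypothetical scoring: reward spikes, penalize contamination; k mirrors
    # the firing_contamination_balance parameter above (assumed form)
    return firing_rate * (1 - (k + 1) * contamination)

# A merge is kept only if it improves on both unmerged units (illustrative check)
accept = merge_score(12.0, 0.05) > max(merge_score(7.0, 0.08), merge_score(5.5, 0.09))
```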