[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 24, 2024
1 parent fd52eee commit 3986b52
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions src/spikeinterface/curation/auto_merge.py
@@ -10,6 +10,7 @@

from .mergeunitssorting import MergeUnitsSorting

+
def get_potential_auto_merge(
    sorting_analyzer,
    minimum_spikes=1000,
@@ -30,7 +31,7 @@ def get_potential_auto_merge(
    firing_contamination_balance=1.5,
    extra_outputs=False,
    steps=None,
-    template_metric='l1'
+    template_metric="l1",
):
    """
    Algorithm to find and check potential merges between units.
@@ -144,7 +145,7 @@ def get_potential_auto_merge(
        to_remove = num_spikes < minimum_spikes
        pair_mask[to_remove, :] = False
        pair_mask[:, to_remove] = False
-
+
    # STEP 2 : remove contaminated auto corr
    if "remove_contaminated" in steps:
        contaminations, nb_violations = compute_refrac_period_violations(
@@ -168,10 +169,10 @@ def get_potential_auto_merge(
        )
        unit_max_chan = list(unit_max_chan.values())
        unit_locations = chan_loc[unit_max_chan, :]
-
+
        unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean")
        pair_mask = pair_mask & (unit_distances <= maximum_distance_um)
-
+
    # STEP 4 : potential auto merge by correlogram
    if "correlogram" in steps:
        correlograms, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba")
@@ -203,14 +204,18 @@ def get_potential_auto_merge(
            templates_ext is not None
        ), "auto_merge with template_similarity requires a SortingAnalyzer with extension templates"

-        templates = templates_ext.get_data(outputs='Templates')
+        templates = templates_ext.get_data(outputs="Templates")
        templates = templates.to_sparse(sorting_analyzer.sparsity)

        templates_diff = compute_templates_diff(
-            sorting, templates, num_channels=num_channels, num_shift=num_shift, pair_mask=pair_mask,
-            template_metric=template_metric
+            sorting,
+            templates,
+            num_channels=num_channels,
+            num_shift=num_shift,
+            pair_mask=pair_mask,
+            template_metric=template_metric,
        )

        pair_mask = pair_mask & (templates_diff < template_diff_thresh)

    # STEP 6 : validate the potential merges with CC increase the contamination quality metrics
@@ -391,7 +396,7 @@ def get_unit_adaptive_window(auto_corr: np.ndarray, threshold: float):
    return win_size


-def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair_mask=None, template_metric='l1'):
+def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair_mask=None, template_metric="l1"):
    """
    Computes normalized template differences.
@@ -446,25 +451,25 @@ def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair
            template2 = template2[:, chan_inds]

            num_samples = template1.shape[0]
-            if template_metric == 'l1':
+            if template_metric == "l1":
                norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2))
-            elif template_metric == 'l2':
+            elif template_metric == "l2":
                norm = np.sum(template1**2) + np.sum(template2**2)
-            elif template_metric == 'cosine':
+            elif template_metric == "cosine":
                norm = np.linalg.norm(template1) * np.linalg.norm(template2)
            all_shift_diff = []
            for shift in range(-num_shift, num_shift + 1):
                temp1 = template1[num_shift : num_samples - num_shift, :]
                temp2 = template2[num_shift + shift : num_samples - num_shift + shift, :]
-                if template_metric == 'l1':
+                if template_metric == "l1":
                    d = np.sum(np.abs(temp1 - temp2)) / norm
-                elif template_metric == 'l2':
+                elif template_metric == "l2":
                    d = np.linalg.norm(temp1 - temp2) / norm
-                elif template_metric == 'cosine':
+                elif template_metric == "cosine":
                    d = min(1, 1 - np.sum(temp1 * temp2) / norm)
                all_shift_diff.append(d)
            templates_diff[unit_ind1, unit_ind2] = np.min(all_shift_diff)

    return templates_diff


@@ -525,9 +530,10 @@ def check_improve_contaminations_score(

    return pair_mask, pairs_removed

-# def apply_potential_merges(sorting_analyzer, potential_merges,
-# firing_contamination_balance=1.5,
-# refractory_period_ms=1,
+
+# def apply_potential_merges(sorting_analyzer, potential_merges,
+# firing_contamination_balance=1.5,
+# refractory_period_ms=1,
# censored_period_ms=0.3):

# contaminations, nb_violations = compute_refrac_period_violations(
@@ -544,7 +550,7 @@ def check_improve_contaminations_score(
# unit1, unit2 = potential_merge
# if unit1 not in sorting.unit_ids or unit2 not in sorting.unit_ids:
# continue
-
+
# for unit in [unit1, unit2]:
# if unit not in graph:
# k = 1 + firing_contamination_balance
@@ -561,7 +567,7 @@ def check_improve_contaminations_score(
# merge = None

# for unit1, unit2 in subgraph.edges:
-
+
# # make a merged sorting and take one unit (unit_id1 is used)
# sorting_merged = MergeUnitsSorting(
# sorting, [[unit1, unit2]], new_unit_ids=[unit1], delta_time_ms=censored_period_ms
@@ -579,7 +585,7 @@ def check_improve_contaminations_score(
# if score > highest_score:
# highest_score = score
# merge = (unit1, unit2)
-
+
# if merge is None:
# scores = dict(subgraph.nodes(data="score"))
# best_unit = max(scores, key=scores.get)
@@ -594,4 +600,4 @@ def check_improve_contaminations_score(
# )
# subgraph = nx.contracted_nodes(subgraph, unit1, unit2, self_loops=False)
# subgraph.nodes[unit1]['score'] = highest_score
-# return sorting
\ No newline at end of file
+# return sorting
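
For reference, the `template_metric` option whose quoting this commit normalizes selects how `compute_templates_diff` turns a pair of unit templates into a normalized distance in the template-similarity step. Below is a minimal, self-contained sketch of the three normalizations, mirroring the hunk above; the toy template arrays, shapes, and random seed are illustrative assumptions, not part of the commit:

import numpy as np

# Toy templates shaped (num_samples, num_channels); in SpikeInterface these
# would come from a SortingAnalyzer "templates" extension.
rng = np.random.default_rng(0)
template1 = rng.normal(size=(90, 4))
template2 = template1 + 0.1 * rng.normal(size=(90, 4))

for template_metric in ("l1", "l2", "cosine"):
    if template_metric == "l1":
        # absolute difference over summed absolute amplitudes
        norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2))
        d = np.sum(np.abs(template1 - template2)) / norm
    elif template_metric == "l2":
        # euclidean difference over summed squared amplitudes
        norm = np.sum(template1**2) + np.sum(template2**2)
        d = np.linalg.norm(template1 - template2) / norm
    elif template_metric == "cosine":
        # cosine distance, clipped so it never exceeds 1
        norm = np.linalg.norm(template1) * np.linalg.norm(template2)
        d = min(1, 1 - np.sum(template1 * template2) / norm)
    print(template_metric, round(float(d), 4))

Per the signature in the diff, a caller would select the metric with something like `get_potential_auto_merge(sorting_analyzer, template_metric="l2")`; pairs whose `templates_diff` falls below `template_diff_thresh` stay in `pair_mask` for the later contamination checks.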
