Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Template similarity lags #2941

Merged
merged 46 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
5d9ff01
WIP
yger May 30, 2024
fb14e9c
WIP
yger May 30, 2024
22ea58d
Max lag in ms
yger May 30, 2024
f9aaada
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2024
eb454bd
Merge branch 'SpikeInterface:main' into template_similarity_lags
yger Jun 2, 2024
452f5ac
WIP
yger Jun 3, 2024
303f1dc
Adding docstrings
yger Jun 3, 2024
6d3daba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
319cfae
Merge branch 'template_similarity_lags' of github.com:yger/spikeinter…
yger Jun 3, 2024
0360497
WIP
yger Jun 4, 2024
65b2f76
Removing prints
yger Jun 4, 2024
fbb2c1a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2024
be4ff29
WIP
yger Jun 5, 2024
2389242
Reformatting
yger Jun 5, 2024
6aecac4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 5, 2024
4073663
Fixing test
yger Jun 6, 2024
b458f70
Merge branch 'SpikeInterface:main' into template_similarity_lags
yger Jun 7, 2024
57b7053
Adding supports for template similarities
yger Jun 8, 2024
7c3e603
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2024
14191db
Addition of lussac metrics
yger Jun 9, 2024
6c04eb5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2024
dfdefea
Merge branch 'main' into template_similarity_lags
yger Jun 9, 2024
d13587a
Some optimizations and docs
yger Jun 9, 2024
be88b6f
Some optimizations and docs
yger Jun 9, 2024
69b607f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2024
657b551
Oups
yger Jun 9, 2024
5f2d47e
Merge branch 'template_similarity_lags' of github.com:yger/spikeinter…
yger Jun 9, 2024
9e930ea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2024
009dacc
Adding the method union_if_intersection
yger Jun 9, 2024
c4265f9
Merge branch 'template_similarity_lags' of github.com:yger/spikeinter…
yger Jun 9, 2024
e399510
Fixes for normalized norms
yger Jun 9, 2024
29bab0e
Fix
yger Jun 9, 2024
8853754
Renaming the metrics
yger Jun 10, 2024
4397a96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
db20802
Fixing bugs
yger Jun 10, 2024
f5dac0f
Merge branch 'template_similarity_lags' of github.com:yger/spikeinter…
yger Jun 10, 2024
2be8ba7
Docs
yger Jun 12, 2024
5afcf9b
union as default
yger Jun 13, 2024
430dae2
Fixing tests
yger Jun 19, 2024
0fffcf7
Merge branch 'main' into template_similarity_lags
yger Jun 24, 2024
875254c
Merge branch 'main' into template_similarity_lags
yger Jun 25, 2024
bcec1d4
Update src/spikeinterface/postprocessing/template_similarity.py
yger Jun 25, 2024
ca8f517
Update src/spikeinterface/postprocessing/template_similarity.py
yger Jun 25, 2024
914dabf
Move sparsity/mask logic to compute_similarity_with_templates_array a…
alejoe91 Jun 25, 2024
572ddf9
Merge branch 'template_similarity_lags' of github.com:yger/spikeinter…
alejoe91 Jun 25, 2024
d707b69
Sam's suggestions: renaming and tests
alejoe91 Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 118 additions & 13 deletions src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,26 @@
class ComputeTemplateSimilarity(AnalyzerExtension):
"""Compute similarity between templates with several methods.

Similarity is defined as 1 - distance(T_1, T_2) for two templates T_1, T_2


Parameters
----------
sorting_analyzer: SortingAnalyzer
sorting_analyzer : SortingAnalyzer
The SortingAnalyzer object
method: str, default: "cosine_similarity"
The method to compute the similarity
method : str, default: "cosine"
The method to compute the similarity. Can be in ["cosine", "l2", "l1"]
max_lag_ms : float, default: 0
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
If specified, the best (smallest) distance over all lags within max_lag_ms is kept, for every pair of templates
support : str, default: "dense"
Support that should be considered to compute the distances between the templates, given their sparsities.
Can be one of ["dense", "union", "intersection"]
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved

In case of "l1" or "l2", the formula used is:
similarity = 1 - norm(T_1 - T_2)/(norm(T_1) + norm(T_2))

In case of "cosine", the distance is the cosine distance 1 - sum(T_1.T_2)/(norm(T_1)norm(T_2)), so the similarity is:
similarity = sum(T_1.T_2)/(norm(T_1)norm(T_2))

Returns
-------
Expand All @@ -32,8 +45,8 @@ class ComputeTemplateSimilarity(AnalyzerExtension):
def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _set_params(self, method="cosine_similarity"):
params = dict(method=method)
def _set_params(self, method="cosine", max_lag_ms=0, support="dense"):
yger marked this conversation as resolved.
Show resolved Hide resolved
params = dict(method=method, max_lag_ms=max_lag_ms, support=support)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
return params

def _select_extension_data(self, unit_ids):
Expand All @@ -43,11 +56,23 @@ def _select_extension_data(self, unit_ids):
return dict(similarity=new_similarity)

def _run(self, verbose=False):
n_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000)
templates_array = get_dense_templates_array(
self.sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled
)
sparsity = self.sorting_analyzer.sparsity
mask = None
if sparsity is not None:
if self.params["support"] == "intersection":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], sparsity.mask[np.newaxis, :, :])
elif self.params["support"] == "union":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], sparsity.mask[np.newaxis, :, :])
units_overlaps = np.sum(mask, axis=2) > 0
mask = np.logical_or(sparsity.mask[:, np.newaxis, :], sparsity.mask[np.newaxis, :, :])
mask[~units_overlaps] = False

similarity = compute_similarity_with_templates_array(
templates_array, templates_array, method=self.params["method"]
templates_array, templates_array, method=self.params["method"], n_shifts=n_shifts, mask=mask
)
self.data["similarity"] = similarity

Expand All @@ -60,25 +85,105 @@ def _get_data(self):
compute_template_similarity = ComputeTemplateSimilarity.function_factory()


def compute_similarity_with_templates_array(templates_array, other_templates_array, method):
def compute_similarity_with_templates_array(templates_array, other_templates_array, method, n_shifts, mask=None):

import sklearn.metrics.pairwise

if method == "cosine_similarity":
method = "cosine"
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved

all_metrics = ["cosine", "l1", "l2"]

if method in all_metrics:
nb_templates = templates_array.shape[0]
assert templates_array.shape[0] == other_templates_array.shape[0]
templates_flat = templates_array.reshape(templates_array.shape[0], -1)
other_templates_flat = templates_array.reshape(other_templates_array.shape[0], -1)
similarity = sklearn.metrics.pairwise.cosine_similarity(templates_flat, other_templates_flat)
n = templates_array.shape[1]
nb_templates = templates_array.shape[0]
assert n_shifts < n, "max_lag is too large"
num_shifts = 2 * n_shifts + 1
distances = np.ones((num_shifts, nb_templates, nb_templates), dtype=np.float32)
if mask is not None:
units_overlaps = np.sum(mask, axis=2) > 0
overlapping_templates = {}
for i in range(nb_templates):
overlapping_templates[i] = np.flatnonzero(units_overlaps[i])

# We can use the fact that dist[i, j] at lag t is equal to dist[j, i] at lag -t,
# so the matrix only needs to be computed for negative lags and transposed for the positive ones

for count, shift in enumerate(range(-n_shifts, 1)):
if mask is None:
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
src_templates = templates_array[:, n_shifts : n - n_shifts].reshape(nb_templates, -1)
tgt_templates = templates_array[:, n_shifts + shift : n - n_shifts + shift].reshape(nb_templates, -1)
if method == "l1":
norms_1 = np.linalg.norm(src_templates, ord=1, axis=1)
norms_2 = np.linalg.norm(tgt_templates, ord=1, axis=1)
denominator = norms_1[:, None] + norms_2[None, :]
distances[count] = sklearn.metrics.pairwise.pairwise_distances(
src_templates, tgt_templates, metric="l1"
)
distances[count] /= denominator
elif method == "l2":
norms_1 = np.linalg.norm(src_templates, ord=2, axis=1)
norms_2 = np.linalg.norm(tgt_templates, ord=2, axis=1)
denominator = norms_1[:, None] + norms_2[None, :]
distances[count] = sklearn.metrics.pairwise.pairwise_distances(
src_templates, tgt_templates, metric="l2"
)
distances[count] /= denominator
else:
distances[count] = sklearn.metrics.pairwise.pairwise_distances(
src_templates, tgt_templates, metric=method
)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved

if n_shifts != 0:
distances[num_shifts - count - 1] = distances[count].T

else:
src_sliced_templates = templates_array[:, n_shifts : n - n_shifts]
tgt_sliced_templates = templates_array[:, n_shifts + shift : n - n_shifts + shift]
for i in range(nb_templates):
src_template = src_sliced_templates[i]
tgt_templates = tgt_sliced_templates[overlapping_templates[i]]
for gcount, j in enumerate(overlapping_templates[i]):
if j < i:
continue
src = src_template[:, mask[i, j]].reshape(1, -1)
tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)

if method == "l1":
norm_i = np.sum(np.abs(src))
norm_j = np.sum(np.abs(tgt))
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1")
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2")
distances[count, i, j] /= norm_i + norm_j
else:
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(
src, tgt, metric=method
)

distances[count, j, i] = distances[count, i, j]

if n_shifts != 0:
distances[num_shifts - count - 1] = distances[count].T

distances = np.min(distances, axis=0)
similarity = 1 - distances

else:
raise ValueError(f"compute_template_similarity(method {method}) not exists")
raise ValueError(f"compute_template_similarity (method {method}) not exists")

return similarity


def compute_template_similarity_by_pair(sorting_analyzer_1, sorting_analyzer_2, method="cosine_similarity"):
def compute_template_similarity_by_pair(sorting_analyzer_1, sorting_analyzer_2, method="cosine", **kwargs):
templates_array_1 = get_dense_templates_array(sorting_analyzer_1, return_scaled=True)
templates_array_2 = get_dense_templates_array(sorting_analyzer_2, return_scaled=True)
similarity = compute_similarity_with_templates_array(templates_array_1, templates_array_2, method)
similarity = compute_similarity_with_templates_array(templates_array_1, templates_array_2, method, **kwargs)
return similarity


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
class SimilarityExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase):
extension_class = ComputeTemplateSimilarity
extension_function_params_list = [
dict(method="cosine_similarity"),
dict(method="cosine"),
dict(method="cosine", max_lag_ms=0.5),
dict(method="l2"),
dict(method="l1"),
]

def test_check_equal_template_with_distribution_overlap(self):
Expand Down