WIP for components in SC2, following Benchmarks (#2900)
SC2 improvements
yger authored May 24, 2024
1 parent 0df2536 commit 95d2917
Showing 3 changed files with 99 additions and 64 deletions.
26 changes: 18 additions & 8 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -60,6 +60,7 @@ class CircusClustering:
"n_svd": [5, 2],
"ms_before": 0.5,
"ms_after": 0.5,
"rank": 5,
"noise_levels": None,
"tmp_folder": None,
"job_kwargs": {},
@@ -208,16 +209,17 @@ def main_function(cls, recording, peaks, params):
**job_kwargs,
)

labels, inverse = np.unique(peak_labels[peak_labels > -1], return_inverse=True)
labels = np.arange(len(labels))
non_noise = peak_labels > -1
labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True)
peak_labels[non_noise] = inverse
labels = np.unique(inverse)

spikes = np.zeros(np.sum(peak_labels > -1), dtype=minimum_spike_dtype)
mask = peak_labels > -1
spikes["sample_index"] = peaks[mask]["sample_index"]
spikes["segment_index"] = peaks[mask]["segment_index"]
spikes["unit_index"] = inverse
spikes = np.zeros(non_noise.sum(), dtype=minimum_spike_dtype)
spikes["sample_index"] = peaks[non_noise]["sample_index"]
spikes["segment_index"] = peaks[non_noise]["segment_index"]
spikes["unit_index"] = peak_labels[non_noise]

unit_ids = np.arange(len(np.unique(spikes["unit_index"])))
unit_ids = labels

nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0)
nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0)
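For reference, a minimal NumPy sketch of the relabeling pattern used in the hunk above: labels of non-noise peaks are compacted to 0..N-1 with np.unique(..., return_inverse=True), while noise peaks keep the label -1 (the array contents below are illustrative, not taken from the commit).

import numpy as np

# Hypothetical clustering output; -1 marks noise peaks.
peak_labels = np.array([-1, 7, 3, 7, -1, 12, 3])

non_noise = peak_labels > -1
labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True)
peak_labels[non_noise] = inverse          # compact unit indices 0..N-1
labels = np.unique(inverse)               # array([0, 1, 2])

print(peak_labels)                        # [-1  1  0  1 -1  2  0]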
@@ -226,6 +228,11 @@ def main_function(cls, recording, peaks, params):
recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs
)

if d["rank"] is not None:
from spikeinterface.sortingcomponents.matching.circus import compress_templates

_, _, _, templates_array = compress_templates(templates_array, d["rank"])

templates = Templates(
templates_array=templates_array,
sampling_frequency=fs,
Expand All @@ -240,7 +247,10 @@ def main_function(cls, recording, peaks, params):
params["noise_levels"] = get_noise_levels(recording, return_scaled=False)
sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
templates = templates.to_sparse(sparsity)
empty_templates = templates.sparsity_mask.sum(axis=1) == 0
templates = remove_empty_templates(templates)
mask = np.isin(peak_labels, np.where(empty_templates)[0])
peak_labels[mask] = -1

if verbose:
print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids)))
45 changes: 15 additions & 30 deletions src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
@@ -536,7 +536,7 @@ def remove_duplicates(
return labels, new_labels


def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, rank=5, multiple_passes=False):
def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, multiple_passes=False):

from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, SharedMemoryRecording
@@ -551,19 +551,6 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,

fs = templates.sampling_frequency
num_chans = len(templates.channel_ids)

if rank is not None:
templates_array = templates.get_dense_templates().copy()
templates_array -= templates_array.mean(axis=(1, 2))[:, None, None]

# Then we keep only the strongest components
temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False)
temporal = temporal[:, :, :rank]
singular = singular[:, :rank]
spatial = spatial[:, :rank, :]

templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial)

norms = np.linalg.norm(templates_array, axis=(1, 2))
margin = max(templates.nbefore, templates.nafter)
tmp_filename = None
@@ -591,11 +578,11 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,
local_params = method_kwargs.copy()
amplitudes = [0.95, 1.05]

local_params.update(
{"templates": templates, "amplitudes": amplitudes, "stop_criteria": "omp_min_sps", "omp_min_sps": 0.5}
)
local_params.update({"templates": templates, "amplitudes": amplitudes})

unit_ids = templates.unit_ids

ignore_ids = []
ignore_inds = []
similar_templates = [[], []]

keep_searching = True
@@ -605,7 +592,7 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,

keep_searching = False

for i in list(set(range(nb_templates)).difference(ignore_ids)):
for i in list(set(range(nb_templates)).difference(ignore_inds)):

## Could be speed up by only computing the values for templates that are
## nearby
@@ -614,7 +601,7 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,
t_stop = margin + (i + 1) * (duration + margin)

sub_recording = recording.frame_slice(t_start, t_stop)
local_params.update({"ignored_ids": ignore_ids + [i]})
local_params.update({"ignore_inds": ignore_inds + [i]})
spikes, computed = find_spikes_from_templates(
sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs
)
@@ -644,26 +631,25 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,

tgt_norm = np.linalg.norm(sum)
ratio = tgt_norm / ref_norm

if (amplitudes[0] < ratio) and (ratio < amplitudes[1]):
if multiple_passes:
keep_searching = True
if np.sum(valid) == 1:
ignore_ids += [i]
similar_templates[1] += [i]
similar_templates[0] += [j]
ignore_inds += [i]
similar_templates[1] += [unit_ids[i]]
similar_templates[0] += [unit_ids[j]]
elif np.sum(valid) > 1:
similar_templates[0] += [-1]
ignore_ids += [i]
similar_templates[1] += [i]
ignore_inds += [i]
similar_templates[1] += [unit_ids[i]]

if DEBUG:
import pylab as plt

fig, axes = plt.subplots(1, 2)
from spikeinterface.widgets import plot_traces

plot_traces(sub_recording, ax=axes[0])
# plot_traces(sub_recording, ax=axes[0])
axes[1].plot(templates_array[i].flatten(), label=f"{ref_norm}")
axes[1].plot(sum.flatten(), label=f"{tgt_norm}")
axes[1].legend()
@@ -678,13 +664,12 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,


def remove_duplicates_via_matching(
templates, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None, rank=5, multiple_passes=False
templates, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None, multiple_passes=False
):

similar_templates = detect_mixtures(
templates, method_kwargs, job_kwargs, tmp_folder=tmp_folder, rank=rank, multiple_passes=multiple_passes
templates, method_kwargs, job_kwargs, tmp_folder=tmp_folder, multiple_passes=multiple_passes
)

new_labels = peak_labels.copy()

labels = np.unique(new_labels)
92 changes: 66 additions & 26 deletions src/spikeinterface/sortingcomponents/matching/circus.py
@@ -38,6 +38,38 @@
from .main import BaseTemplateMatchingEngine


def compress_templates(templates_array, approx_rank, remove_mean=True, return_new_templates=True):
"""Compress templates using singular value decomposition.
Parameters
----------
templates_array : ndarray (num_templates, num_samples, num_channels)
Spike template waveforms.
approx_rank : int
Rank of the compressed template matrices.
remove_mean : bool, default: True
If True, the mean of each template is subtracted before the decomposition.
return_new_templates : bool, default: True
If True, the low-rank reconstruction of the templates is also returned.
Returns
-------
temporal, singular, spatial : ndarray, ndarray, ndarray
Templates compressed by singular value decomposition into temporal, singular, and spatial components.
templates_array : ndarray or None
Low-rank reconstruction of the input templates, or None if return_new_templates is False.
"""
if remove_mean:
templates_array -= templates_array.mean(axis=(1, 2))[:, None, None]

temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False)
# Keep only the strongest components
temporal = temporal[:, :, :approx_rank]
singular = singular[:, :approx_rank]
spatial = spatial[:, :approx_rank, :]

if return_new_templates:
templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
else:
templates_array = None

return temporal, singular, spatial, templates_array
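As a sanity check on the helper above, here is a minimal standalone sketch of the same SVD truncation applied to random data (shapes and names are illustrative); keeping only approx_rank components per template yields a low-rank approximation of the original array.

import numpy as np

rng = np.random.default_rng(0)
num_templates, num_samples, num_channels = 4, 60, 16
templates_array = rng.standard_normal((num_templates, num_samples, num_channels)).astype("float32")
approx_rank = 5

arr = templates_array - templates_array.mean(axis=(1, 2))[:, None, None]
temporal, singular, spatial = np.linalg.svd(arr, full_matrices=False)
temporal = temporal[:, :, :approx_rank]     # (num_templates, num_samples, approx_rank)
singular = singular[:, :approx_rank]        # (num_templates, approx_rank)
spatial = spatial[:, :approx_rank, :]       # (num_templates, approx_rank, num_channels)

low_rank = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
relative_error = np.linalg.norm(low_rank - arr) / np.linalg.norm(arr)
print(low_rank.shape, relative_error)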


def compute_overlaps(templates, num_samples, num_channels, sparsities):
num_templates = len(templates)

@@ -103,14 +135,14 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine):
"""

_default_params = {
"amplitudes": [0.6, 2],
"amplitudes": [0.6, np.inf],
"stop_criteria": "max_failures",
"max_failures": 20,
"max_failures": 10,
"omp_min_sps": 0.1,
"relative_error": 5e-5,
"templates": None,
"rank": 5,
"ignored_ids": [],
"ignore_inds": [],
"vicinity": 3,
}

@@ -130,17 +162,8 @@ def _prepare_templates(cls, d):
(d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i])

templates_array = templates.get_dense_templates().copy()
templates_array -= templates_array.mean(axis=(1, 2))[:, None, None]

# Then we keep only the strongest components
rank = d["rank"]
temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False)
d["temporal"] = temporal[:, :, :rank]
d["singular"] = singular[:, :rank]
d["spatial"] = spatial[:, :rank, :]

# We reconstruct the approximated templates
templates_array = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"])
d["temporal"], d["singular"], d["spatial"], templates_array = compress_templates(templates_array, d["rank"])

d["normed_templates"] = np.zeros(templates_array.shape, dtype=np.float32)
d["norms"] = np.zeros(num_templates, dtype=np.float32)
@@ -155,6 +178,7 @@ def _prepare_templates(cls, d):
d["temporal"] = np.flip(d["temporal"], axis=1)

d["overlaps"] = []
d["max_similarity"] = np.zeros((num_templates, num_templates), dtype=np.float32)
for i in range(num_templates):
num_overlaps = np.sum(d["units_overlaps"][i])
overlapping_units = np.where(d["units_overlaps"][i])[0]
@@ -177,8 +201,17 @@ def _prepare_templates(cls, d):
for rank in range(visible_i.shape[1]):
unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d["temporal"][j][:, rank], mode="full")

d["max_similarity"][i, j] = np.max(unit_overlaps[count])

d["overlaps"].append(unit_overlaps)

if d["amplitudes"] is None:
distances = np.sort(d["max_similarity"], axis=1)[:, ::-1]
distances = 1 - distances[:, 1] / 2
d["amplitudes"] = np.zeros((num_templates, 2))
d["amplitudes"][:, 0] = distances
d["amplitudes"][:, 1] = np.inf

d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2])
d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0])
d["singular"] = d["singular"].T[:, :, np.newaxis]
@@ -214,7 +247,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
assert d[key] is not None, "If templates are provided, %s should also be there" % key

d["num_templates"] = len(d["templates"].templates_array)
d["ignored_ids"] = np.array(d["ignored_ids"])
d["ignore_inds"] = np.array(d["ignore_inds"])

d["unit_overlaps_tables"] = {}
for i in range(d["num_templates"]):
@@ -243,15 +276,20 @@ def get_margin(cls, recording, kwargs):
@classmethod
def main_function(cls, traces, d):
num_templates = d["num_templates"]
num_channels = d["num_channels"]
num_samples = d["num_samples"]
num_channels = d["num_channels"]
overlaps_array = d["overlaps"]
norms = d["norms"]
omp_tol = np.finfo(np.float32).eps
num_samples = d["nafter"] + d["nbefore"]
neighbor_window = num_samples - 1
min_amplitude, max_amplitude = d["amplitudes"]
ignored_ids = d["ignored_ids"]
if isinstance(d["amplitudes"], list):
min_amplitude, max_amplitude = d["amplitudes"]
else:
min_amplitude, max_amplitude = d["amplitudes"][:, 0], d["amplitudes"][:, 1]
min_amplitude = min_amplitude[:, np.newaxis]
max_amplitude = max_amplitude[:, np.newaxis]
ignore_inds = d["ignore_inds"]
vicinity = d["vicinity"]

num_timesteps = len(traces)
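With per-template bounds, min_amplitude and max_amplitude become column vectors so that they broadcast against fitted amplitudes indexed by template, as in this toy illustration (shapes and values are hypothetical, not the library's internals):

import numpy as np

amplitudes = np.array([[0.6, np.inf],
                       [0.9, np.inf]])             # (num_templates, 2) bounds
final_amplitudes = np.array([[0.70, 0.50, 1.20],
                             [0.95, 0.20, 0.80]])  # (num_templates, num_peaks), made up

if isinstance(amplitudes, list):                   # global [min, max] bounds
    min_amplitude, max_amplitude = amplitudes
else:                                              # per-template bounds
    min_amplitude = amplitudes[:, 0][:, np.newaxis]
    max_amplitude = amplitudes[:, 1][:, np.newaxis]

is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude)
print(is_valid)   # each row is checked against its own template's bounds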
@@ -261,15 +299,15 @@ def main_function(cls, traces, d):
scalar_products = np.zeros(conv_shape, dtype=np.float32)

# Filter using overlap-and-add convolution
if len(ignored_ids) > 0:
not_ignored = ~np.isin(np.arange(num_templates), ignored_ids)
if len(ignore_inds) > 0:
not_ignored = ~np.isin(np.arange(num_templates), ignore_inds)
spatially_filtered_data = np.matmul(d["spatial"][:, not_ignored, :], traces.T[np.newaxis, :, :])
scaled_filtered_data = spatially_filtered_data * d["singular"][:, not_ignored, :]
objective_by_rank = scipy.signal.oaconvolve(
scaled_filtered_data, d["temporal"][:, not_ignored, :], axes=2, mode="valid"
)
scalar_products[not_ignored] += np.sum(objective_by_rank, axis=0)
scalar_products[ignored_ids] = -np.inf
scalar_products[ignore_inds] = -np.inf
else:
spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :])
scaled_filtered_data = spatially_filtered_data * d["singular"]
@@ -296,10 +334,10 @@ def main_function(cls, traces, d):
if d["stop_criteria"] == "omp_min_sps":
stop_criteria = d["omp_min_sps"] * np.maximum(d["norms"], np.sqrt(num_channels * num_samples))
elif d["stop_criteria"] == "max_failures":
nb_valids = 0
num_valids = 0
nb_failures = d["max_failures"]
elif d["stop_criteria"] == "relative_error":
if len(ignored_ids) > 0:
if len(ignore_inds) > 0:
new_error = np.linalg.norm(scalar_products[not_ignored])
else:
new_error = np.linalg.norm(scalar_products)
@@ -409,14 +447,16 @@ def main_function(cls, traces, d):
do_loop = np.any(is_valid)
elif d["stop_criteria"] == "max_failures":
is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude)
new_nb_valids = np.sum(is_valid)
if (new_nb_valids - nb_valids) == 0:
new_num_valids = np.sum(is_valid)
if (new_num_valids - num_valids) > 0:
nb_failures = d["max_failures"]
else:
nb_failures -= 1
nb_valids = new_nb_valids
num_valids = new_num_valids
do_loop = nb_failures > 0
elif d["stop_criteria"] == "relative_error":
previous_error = new_error
if len(ignored_ids) > 0:
if len(ignore_inds) > 0:
new_error = np.linalg.norm(scalar_products[not_ignored])
else:
new_error = np.linalg.norm(scalar_products)
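The max_failures criterion above behaves like a patience counter: an iteration that accepts at least one additional spike within the amplitude bounds resets the counter, an iteration without progress decrements it, and fitting stops once the counter reaches zero. A compact sketch of that control flow (run_omp_iteration is a hypothetical stand-in for one OMP pass):

import numpy as np

def fit_with_max_failures(run_omp_iteration, max_failures=10):
    """Toy version of the 'max_failures' stopping criterion, not the library API."""
    num_valids = 0
    nb_failures = max_failures
    do_loop = True
    while do_loop:
        is_valid = run_omp_iteration()            # boolean mask of accepted spikes
        new_num_valids = int(np.sum(is_valid))
        if (new_num_valids - num_valids) > 0:     # progress: reset the counter
            nb_failures = max_failures
        else:                                     # no new accepted spike
            nb_failures -= 1
        num_valids = new_num_valids
        do_loop = nb_failures > 0
    return num_valids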