Skip to content

Commit

Permalink
Merge branch 'svd_convolutions' of github.com:yger/spikeinterface int…
Browse files Browse the repository at this point in the history
…o svd_convolutions
  • Loading branch information
yger committed Sep 27, 2023
2 parents 8da6b79 + 97aff7f commit e4189a9
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions src/spikeinterface/sortingcomponents/matching/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,20 +529,18 @@ def _prepare_templates(cls, d):
else:
sparsity = waveform_extractor.sparsity.mask

d['sparsity_mask'] = sparsity
units_overlaps = np.sum(
np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2
)
d['units_overlaps'] = units_overlaps > 0
d['unit_overlaps_indices'] = {}
d["sparsity_mask"] = sparsity
units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2)
d["units_overlaps"] = units_overlaps > 0
d["unit_overlaps_indices"] = {}
for i in range(num_templates):
d['unit_overlaps_indices'][i], = np.nonzero(d['units_overlaps'][i])
(d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i])

templates = waveform_extractor.get_all_templates(mode="median").copy()

# First, we set masked channels to 0
for count in range(num_templates):
templates[count][:, ~d['sparsity_mask'][count]] = 0
templates[count][:, ~d["sparsity_mask"][count]] = 0

# Then we keep only the strongest components
rank = d["rank"]
Expand All @@ -559,37 +557,37 @@ def _prepare_templates(cls, d):

# And get the norms, saving compressed templates for CC matrix
for count in range(num_templates):
template = templates[count][:, d['sparsity_mask'][count]]
template = templates[count][:, d["sparsity_mask"][count]]
d["norms"][count] = np.linalg.norm(template)
d["templates"][count] = template / d["norms"][count]

d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis]
d["temporal"] = np.flip(d["temporal"], axis=1)

d['overlaps'] = []
d["overlaps"] = []
for i in range(num_templates):
num_overlaps = np.sum(d['units_overlaps'][i])
overlapping_units = np.where(d['units_overlaps'][i])[0]
num_overlaps = np.sum(d["units_overlaps"][i])
overlapping_units = np.where(d["units_overlaps"][i])[0]

# Reconstruct unit template from SVD Matrices
data = d['temporal'][i] * d['singular'][i][np.newaxis, :]
template_i = np.matmul(data, d['spatial'][i, :, :])
data = d["temporal"][i] * d["singular"][i][np.newaxis, :]
template_i = np.matmul(data, d["spatial"][i, :, :])
template_i = np.flipud(template_i)

unit_overlaps = np.zeros([num_overlaps, 2*d['num_samples'] - 1], dtype=np.float32)
unit_overlaps = np.zeros([num_overlaps, 2 * d["num_samples"] - 1], dtype=np.float32)

for count, j in enumerate(overlapping_units):
overlapped_channels = d['sparsity_mask'][j]
overlapped_channels = d["sparsity_mask"][j]
visible_i = template_i[:, overlapped_channels]

spatial_filters = d['spatial'][j, :, overlapped_channels]
spatial_filters = d["spatial"][j, :, overlapped_channels]
spatially_filtered_template = np.matmul(visible_i, spatial_filters)
visible_i = spatially_filtered_template * d['singular'][j]
visible_i = spatially_filtered_template * d["singular"][j]

for rank in range(visible_i.shape[1]):
unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d['temporal'][j][:, rank], mode='full')
unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d["temporal"][j][:, rank], mode="full")

d['overlaps'].append(unit_overlaps)
d["overlaps"].append(unit_overlaps)

d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2])
d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0])
Expand Down Expand Up @@ -620,7 +618,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
if "templates" not in d:
d = cls._prepare_templates(d)
else:
for key in ["norms", "temporal", "spatial", "singular", "units_overlaps", "sparsity_mask", "unit_overlaps_indices"]:
for key in [
"norms",
"temporal",
"spatial",
"singular",
"units_overlaps",
"sparsity_mask",
"unit_overlaps_indices",
]:
assert d[key] is not None, "If templates are provided, %d should also be there" % key

d["num_templates"] = len(d["templates"])
Expand Down Expand Up @@ -713,7 +719,7 @@ def main_function(cls, traces, d):
myindices = selection[0, idx]

local_overlaps = overlaps[best_cluster_ind]
overlapping_templates = d['unit_overlaps_indices'][best_cluster_ind]
overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind]

if num_selection == M.shape[0]:
Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32)
Expand Down Expand Up @@ -788,7 +794,7 @@ def main_function(cls, traces, d):
diff_amp = diff_amplitudes[i] * norms[tmp_best]

local_overlaps = overlaps[tmp_best]
overlapping_templates = d['units_overlaps'][tmp_best]
overlapping_templates = d["units_overlaps"][tmp_best]

if not tmp_peak in neighbors.keys():
idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)]
Expand Down

0 comments on commit e4189a9

Please sign in to comment.