diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index e047cbdd31..5924d3bc18 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -123,20 +123,18 @@ def _prepare_templates(cls, d): else: sparsity = waveform_extractor.sparsity.mask - d['sparsity_mask'] = sparsity - units_overlaps = np.sum( - np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2 - ) - d['units_overlaps'] = units_overlaps > 0 - d['unit_overlaps_indices'] = {} + d["sparsity_mask"] = sparsity + units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) + d["units_overlaps"] = units_overlaps > 0 + d["unit_overlaps_indices"] = {} for i in range(num_templates): - d['unit_overlaps_indices'][i], = np.nonzero(d['units_overlaps'][i]) + (d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i]) templates = waveform_extractor.get_all_templates(mode="median").copy() # First, we set masked channels to 0 for count in range(num_templates): - templates[count][:, ~d['sparsity_mask'][count]] = 0 + templates[count][:, ~d["sparsity_mask"][count]] = 0 # Then we keep only the strongest components rank = d["rank"] @@ -153,37 +151,37 @@ def _prepare_templates(cls, d): # And get the norms, saving compressed templates for CC matrix for count in range(num_templates): - template = templates[count][:, d['sparsity_mask'][count]] + template = templates[count][:, d["sparsity_mask"][count]] d["norms"][count] = np.linalg.norm(template) d["templates"][count] = template / d["norms"][count] d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] d["temporal"] = np.flip(d["temporal"], axis=1) - d['overlaps'] = [] + d["overlaps"] = [] for i in range(num_templates): - num_overlaps = np.sum(d['units_overlaps'][i]) - overlapping_units = np.where(d['units_overlaps'][i])[0] + num_overlaps = np.sum(d["units_overlaps"][i]) + overlapping_units = np.where(d["units_overlaps"][i])[0] # Reconstruct unit template from SVD Matrices - data = d['temporal'][i] * d['singular'][i][np.newaxis, :] - template_i = np.matmul(data, d['spatial'][i, :, :]) + data = d["temporal"][i] * d["singular"][i][np.newaxis, :] + template_i = np.matmul(data, d["spatial"][i, :, :]) template_i = np.flipud(template_i) - unit_overlaps = np.zeros([num_overlaps, 2*d['num_samples'] - 1], dtype=np.float32) + unit_overlaps = np.zeros([num_overlaps, 2 * d["num_samples"] - 1], dtype=np.float32) for count, j in enumerate(overlapping_units): - overlapped_channels = d['sparsity_mask'][j] + overlapped_channels = d["sparsity_mask"][j] visible_i = template_i[:, overlapped_channels] - spatial_filters = d['spatial'][j, :, overlapped_channels] + spatial_filters = d["spatial"][j, :, overlapped_channels] spatially_filtered_template = np.matmul(visible_i, spatial_filters) - visible_i = spatially_filtered_template * d['singular'][j] - + visible_i = spatially_filtered_template * d["singular"][j] + for rank in range(visible_i.shape[1]): - unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d['temporal'][j][:, rank], mode='full') + unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d["temporal"][j][:, rank], mode="full") - d['overlaps'].append(unit_overlaps) + d["overlaps"].append(unit_overlaps) d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) @@ -214,7 +212,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs): if "templates" not in d: d = cls._prepare_templates(d) else: - for key in ["norms", "temporal", "spatial", "singular", "units_overlaps", "sparsity_mask", "unit_overlaps_indices"]: + for key in [ + "norms", + "temporal", + "spatial", + "singular", + "units_overlaps", + "sparsity_mask", + "unit_overlaps_indices", + ]: assert d[key] is not None, "If templates are provided, %d should also be there" % key d["num_templates"] = len(d["templates"]) @@ -307,7 +313,7 @@ def main_function(cls, traces, d): myindices = selection[0, idx] local_overlaps = overlaps[best_cluster_ind] - overlapping_templates = d['unit_overlaps_indices'][best_cluster_ind] + overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind] if num_selection == M.shape[0]: Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) @@ -382,7 +388,7 @@ def main_function(cls, traces, d): diff_amp = diff_amplitudes[i] * norms[tmp_best] local_overlaps = overlaps[tmp_best] - overlapping_templates = d['units_overlaps'][tmp_best] + overlapping_templates = d["units_overlaps"][tmp_best] if not tmp_peak in neighbors.keys(): idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)]