From 9dc04f1fa68cf7202eed224394bb60b95a7b4e6d Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 14 Jul 2023 13:44:25 +0200 Subject: [PATCH 01/22] WIP --- .../sortingcomponents/matching/circus.py | 517 ++++++++---------- 1 file changed, 218 insertions(+), 299 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 2196320378..8f08aac9c5 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -16,7 +16,8 @@ except ImportError: HAVE_SKLEARN = False -from spikeinterface.core import get_noise_levels, get_random_data_chunks + +from spikeinterface.core import get_noise_levels, get_random_data_chunks, compute_sparsity from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) @@ -130,6 +131,38 @@ def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): return ret +def compute_overlaps(templates, num_samples, num_channels, sparsities): + + num_templates = len(templates) + + dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) + for i in range(num_templates): + dense_templates[i, :, sparsities[i]] = templates[i].T + + size = 2 * num_samples - 1 + + all_delays = list(range(0, num_samples+1)) + + overlaps = {} + + for delay in all_delays: + source = dense_templates[:, :delay, :].reshape(num_templates, -1) + target = dense_templates[:, num_samples-delay:, :].reshape(num_templates, -1) + + overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) + + if delay < num_samples: + overlaps[size - delay + 1] = overlaps[delay].T.tocsr() + + new_overlaps = [] + + for i in range(num_templates): + data = [overlaps[j][i, :].T for j in range(size)] + data = scipy.sparse.hstack(data) + new_overlaps += [data] + + return new_overlaps + class CircusOMPPeeler(BaseTemplateMatchingEngine): """ @@ -152,11 +185,6 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): (Minimal, Maximal) amplitudes allowed for every template omp_min_sps: float Stopping criteria of the OMP algorithm, in percentage of the norm - sparsify_threshold: float - Templates are sparsified in order to keep only the channels necessary - to explain. ptp limit for considering a channel as silent - smoothing_factor: float - Templates are smoothed via Spline Interpolation noise_levels: array The noise levels, for every channels. If None, they will be automatically computed @@ -175,133 +203,77 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): "norms": None, "random_chunk_kwargs": {}, "noise_levels": None, - "smoothing_factor": 0.25, + 'sparse_kwargs' : {'method' : 'ptp', 'threshold' : 1}, "ignored_ids": [], + "vicinity" : 0 } - @classmethod - def _sparsify_template(cls, template, sparsify_threshold): - is_silent = template.ptp(0) < sparsify_threshold - template[:, is_silent] = 0 - (active_channels,) = np.where(np.logical_not(is_silent)) - - return template, active_channels - - @classmethod - def _regularize_template(cls, template, smoothing_factor=0.25): - nb_channels = template.shape[1] - nb_timesteps = template.shape[0] - xaxis = np.arange(nb_timesteps) - for i in range(nb_channels): - z = scipy.interpolate.UnivariateSpline(xaxis, template[:, i]) - z.set_smoothing_factor(smoothing_factor) - template[:, i] = z(xaxis) - return template - @classmethod def _prepare_templates(cls, d): - waveform_extractor = d["waveform_extractor"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = len(d["waveform_extractor"].sorting.unit_ids) + + waveform_extractor = d['waveform_extractor'] + num_templates = len(d['waveform_extractor'].sorting.unit_ids) - templates = waveform_extractor.get_all_templates(mode="median").copy() + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d['sparse_kwargs']).mask + else: + sparsity = waveform_extractor.sparsity.mask + + templates = waveform_extractor.get_all_templates(mode='median').copy() - d["sparsities"] = {} - d["templates"] = {} - d["norms"] = np.zeros(num_templates, dtype=np.float32) + d['sparsities'] = {} + d['templates'] = {} + d['norms'] = np.zeros(num_templates, dtype=np.float32) for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - if d["smoothing_factor"] > 0: - template = cls._regularize_template(templates[count], d["smoothing_factor"]) - else: - template = templates[count] - template, active_channels = cls._sparsify_template(template, d["sparsify_threshold"]) - d["sparsities"][count] = active_channels - d["norms"][count] = np.linalg.norm(template) - d["templates"][count] = template[:, active_channels] / d["norms"][count] - - return d - - @classmethod - def _prepare_overlaps(cls, d): - templates = d["templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - sparsities = d["sparsities"] - - dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) - for i in range(num_templates): - dense_templates[i, :, sparsities[i]] = templates[i].T - - size = 2 * num_samples - 1 - - all_delays = list(range(0, num_samples + 1)) - - overlaps = {} - - for delay in all_delays: - source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) - - overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) - - if delay < num_samples: - overlaps[size - delay + 1] = overlaps[delay].T.tocsr() - - new_overlaps = [] - - for i in range(num_templates): - data = [overlaps[j][i, :].T for j in range(size)] - data = scipy.sparse.hstack(data) - new_overlaps += [data] - - d["overlaps"] = new_overlaps + template = templates[count] + d['sparsities'][count], = np.nonzero(sparsity[count]) + d['norms'][count] = np.linalg.norm(template) + d['templates'][count] = template[:, d['sparsities'][count]]/d['norms'][count] return d @classmethod def initialize_and_check_kwargs(cls, recording, kwargs): + d = cls._default_params.copy() d.update(kwargs) - # assert isinstance(d['waveform_extractor'], WaveformExtractor) - - for v in ["omp_min_sps"]: - assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" + #assert isinstance(d['waveform_extractor'], WaveformExtractor) - d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() - d["num_samples"] = d["waveform_extractor"].nsamples - d["nbefore"] = d["waveform_extractor"].nbefore - d["nafter"] = d["waveform_extractor"].nafter - d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + for v in ['omp_min_sps']: + assert (d[v] >= 0) and (d[v] <= 1), f'{v} should be in [0, 1]' + + d['num_channels'] = d['waveform_extractor'].recording.get_num_channels() + d['num_samples'] = d['waveform_extractor'].nsamples + d['nbefore'] = d['waveform_extractor'].nbefore + d['nafter'] = d['waveform_extractor'].nafter + d['sampling_frequency'] = d['waveform_extractor'].recording.get_sampling_frequency() + d['vicinity'] *= d['num_samples'] - if d["noise_levels"] is None: - print("CircusOMPPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) + if d['noise_levels'] is None: + print('CircusOMPPeeler : noise should be computed outside') + d['noise_levels'] = get_noise_levels(recording, **d['random_chunk_kwargs'], return_scaled=False) - if d["templates"] is None: + if d['templates'] is None: d = cls._prepare_templates(d) else: - for key in ["norms", "sparsities"]: - assert d[key] is not None, "If templates are provided, %d should also be there" % key + for key in ['norms', 'sparsities']: + assert d[key] is not None, "If templates are provided, %d should also be there" %key - d["num_templates"] = len(d["templates"]) + d['num_templates'] = len(d['templates']) - if d["overlaps"] is None: - d = cls._prepare_overlaps(d) + if d['overlaps'] is None: + d['overlaps'] = compute_overlaps(d['templates'], d['num_samples'], d['num_channels'], d['sparsities']) - d["ignored_ids"] = np.array(d["ignored_ids"]) + d['ignored_ids'] = np.array(d['ignored_ids']) - omp_min_sps = d["omp_min_sps"] - norms = d["norms"] - sparsities = d["sparsities"] + omp_min_sps = d['omp_min_sps'] + nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) + d['stop_criteria'] = omp_min_sps * np.sqrt(nb_active_channels * d['num_samples']) - nb_active_channels = np.array([len(sparsities[i]) for i in range(d["num_templates"])]) - d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + return d - return d @classmethod def serialize_method_kwargs(cls, kwargs): @@ -321,26 +293,27 @@ def get_margin(cls, recording, kwargs): @classmethod def main_function(cls, traces, d): - templates = d["templates"] - num_templates = d["num_templates"] - num_channels = d["num_channels"] - num_samples = d["num_samples"] - overlaps = d["overlaps"] - norms = d["norms"] - nbefore = d["nbefore"] - nafter = d["nafter"] + templates = d['templates'] + num_templates = d['num_templates'] + num_channels = d['num_channels'] + num_samples = d['num_samples'] + overlaps = d['overlaps'] + norms = d['norms'] + nbefore = d['nbefore'] + nafter = d['nafter'] omp_tol = np.finfo(np.float32).eps - num_samples = d["nafter"] + d["nbefore"] + num_samples = d['nafter'] + d['nbefore'] neighbor_window = num_samples - 1 - min_amplitude, max_amplitude = d["amplitudes"] - sparsities = d["sparsities"] - ignored_ids = d["ignored_ids"] - stop_criteria = d["stop_criteria"] + min_amplitude, max_amplitude = d['amplitudes'] + sparsities = d['sparsities'] + ignored_ids = d['ignored_ids'] + stop_criteria = d['stop_criteria'][:, np.newaxis] + vicinity = d['vicinity'] - if "cached_fft_kernels" not in d: - d["cached_fft_kernels"] = {"fshape": 0} + if 'cached_fft_kernels' not in d: + d['cached_fft_kernels'] = {'fshape' : 0} - cached_fft_kernels = d["cached_fft_kernels"] + cached_fft_kernels = d['cached_fft_kernels'] num_timesteps = len(traces) @@ -352,22 +325,24 @@ def main_function(cls, traces, d): dummy_traces = np.empty((num_channels, num_timesteps), dtype=np.float32) fshape, axes = get_scipy_shape(dummy_filter, traces, axes=1) - fft_cache = {"full": sp_fft.rfftn(traces, fshape, axes=axes)} + fft_cache = {'full' : sp_fft.rfftn(traces, fshape, axes=axes)} scalar_products = np.empty((num_templates, num_peaks), dtype=np.float32) - flagged_chunk = cached_fft_kernels["fshape"] != fshape[0] + flagged_chunk = cached_fft_kernels['fshape'] != fshape[0] for i in range(num_templates): + if i not in ignored_ids: + if i not in cached_fft_kernels or flagged_chunk: kernel_filter = np.ascontiguousarray(templates[i][::-1].T) - cached_fft_kernels.update({i: sp_fft.rfftn(kernel_filter, fshape, axes=axes)}) - cached_fft_kernels["fshape"] = fshape[0] + cached_fft_kernels.update({i : sp_fft.rfftn(kernel_filter, fshape, axes=axes)}) + cached_fft_kernels['fshape'] = fshape[0] - fft_cache.update({"mask": sparsities[i], "template": cached_fft_kernels[i]}) + fft_cache.update({'mask' : sparsities[i], 'template' : cached_fft_kernels[i]}) - convolution = fftconvolve_with_cache(dummy_filter, dummy_traces, fft_cache, axes=1, mode="valid") + convolution = fftconvolve_with_cache(dummy_filter, dummy_traces, fft_cache, axes=1, mode='valid') if len(convolution) > 0: scalar_products[i] = convolution.sum(0) else: @@ -381,7 +356,7 @@ def main_function(cls, traces, d): spikes = np.empty(scalar_products.size, dtype=spike_dtype) idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) - M = np.zeros((num_peaks, num_peaks), dtype=np.float32) + M = np.zeros((100, 100), dtype=np.float32) all_selections = np.empty((2, scalar_products.size), dtype=np.int32) final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) @@ -392,13 +367,17 @@ def main_function(cls, traces, d): neighbors = {} cached_overlaps = {} - is_valid = scalar_products > stop_criteria + is_valid = (scalar_products > stop_criteria) + all_amplitudes = np.zeros(0, dtype=np.float32) + is_in_vicinity = np.zeros(0, dtype=np.int32) while np.any(is_valid): + best_amplitude_ind = scalar_products[is_valid].argmax() best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) - + if num_selection > 0: + delta_t = selection[1] - peak_index idx = np.where((delta_t < neighbor_window) & (delta_t > -num_samples))[0] myline = num_samples + delta_t[idx] @@ -407,25 +386,42 @@ def main_function(cls, traces, d): cached_overlaps[best_cluster_ind] = overlaps[best_cluster_ind].toarray() if num_selection == M.shape[0]: - Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) + Z = np.zeros((2*num_selection, 2*num_selection), dtype=np.float32) Z[:num_selection, :num_selection] = M M = Z M[num_selection, idx] = cached_overlaps[best_cluster_ind][selection[0, idx], myline] - scipy.linalg.solve_triangular( - M[:num_selection, :num_selection], - M[num_selection, :num_selection], - trans=0, - lower=1, - overwrite_b=True, - check_finite=False, - ) - - v = nrm2(M[num_selection, :num_selection]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) + + if vicinity == 0: + scipy.linalg.solve_triangular(M[:num_selection, :num_selection], M[num_selection, :num_selection], trans=0, + lower=1, + overwrite_b=True, + check_finite=False) + + v = nrm2(M[num_selection, :num_selection]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + + if len(is_in_vicinity) > 0: + + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular(L, M[num_selection, is_in_vicinity], trans=0, + lower=1, + overwrite_b=True, + check_finite=False) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 else: M[0, 0] = 1 @@ -435,45 +431,54 @@ def main_function(cls, traces, d): selection = all_selections[:, :num_selection] res_sps = full_sps[selection[0], selection[1]] - all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) - - all_amplitudes /= norms[selection[0]] - - diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] + if vicinity == 0: + all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, + lower=True, overwrite_b=False) + all_amplitudes /= norms[selection[0]] + else: + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(0)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], + lower=True, overwrite_b=False) + all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] + + diff_amplitudes = (all_amplitudes - final_amplitudes[selection[0], selection[1]]) modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] final_amplitudes[selection[0], selection[1]] = all_amplitudes for i in modified: - tmp_best, tmp_peak = selection[:, i] - diff_amp = diff_amplitudes[i] * norms[tmp_best] + tmp_best, tmp_peak = selection[:, i] + diff_amp = diff_amplitudes[i]*norms[tmp_best] + if not tmp_best in cached_overlaps: cached_overlaps[tmp_best] = overlaps[tmp_best].toarray() if not tmp_peak in neighbors.keys(): idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)] tdx = [num_samples + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak] - neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} + neighbors[tmp_peak] = {'idx' : idx, 'tdx' : tdx} - idx = neighbors[tmp_peak]["idx"] - tdx = neighbors[tmp_peak]["tdx"] + idx = neighbors[tmp_peak]['idx'] + tdx = neighbors[tmp_peak]['tdx'] - to_add = diff_amp * cached_overlaps[tmp_best][:, tdx[0] : tdx[1]] - scalar_products[:, idx[0] : idx[1]] -= to_add + to_add = diff_amp * cached_overlaps[tmp_best][:, tdx[0]:tdx[1]] + scalar_products[:, idx[0]:idx[1]] -= to_add - is_valid = scalar_products > stop_criteria + is_valid = (scalar_products > stop_criteria) - is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) + is_valid = (final_amplitudes > min_amplitude)*(final_amplitudes < max_amplitude) valid_indices = np.where(is_valid) num_spikes = len(valid_indices[0]) - spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] - spikes["channel_index"][:num_spikes] = 0 - spikes["cluster_index"][:num_spikes] = valid_indices[0] - spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] - + spikes['sample_index'][:num_spikes] = valid_indices[1] + d['nbefore'] + spikes['channel_index'][:num_spikes] = 0 + spikes['cluster_index'][:num_spikes] = valid_indices[0] + spikes['amplitude'][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] + spikes = spikes[:num_spikes] - order = np.argsort(spikes["sample_index"]) + order = np.argsort(spikes['sample_index']) spikes = spikes[order] return spikes @@ -515,9 +520,6 @@ class CircusPeeler(BaseTemplateMatchingEngine): Maximal amplitude allowed for every template min_amplitude: float Minimal amplitude allowed for every template - sparsify_threshold: float - Templates are sparsified in order to keep only the channels necessary - to explain a given fraction of the total norm use_sparse_matrix_threshold: float If density of the templates is below a given threshold, sparse matrix are used (memory efficient) @@ -529,129 +531,57 @@ class CircusPeeler(BaseTemplateMatchingEngine): """ _default_params = { - "peak_sign": "neg", - "exclude_sweep_ms": 0.1, - "jitter_ms": 0.1, - "detect_threshold": 5, - "noise_levels": None, - "random_chunk_kwargs": {}, - "sparsify_threshold": 0.99, - "max_amplitude": 1.5, - "min_amplitude": 0.5, - "use_sparse_matrix_threshold": 0.25, - "progess_bar_steps": False, - "waveform_extractor": None, - "smoothing_factor": 0.25, + 'peak_sign': 'neg', + 'exclude_sweep_ms': 0.1, + 'jitter_ms' : 0.1, + 'detect_threshold': 5, + 'noise_levels': None, + 'random_chunk_kwargs': {}, + 'max_amplitude' : 1.5, + 'min_amplitude' : 0.5, + 'use_sparse_matrix_threshold' : 0.25, + 'progess_bar_steps' : False, + 'waveform_extractor': None, + 'sparse_kwargs' : {'method' : 'threshold', 'threshold' : 0.5, 'peak_sign' : 'both'} } - @classmethod - def _sparsify_template(cls, template, sparsify_threshold, noise_levels): - is_silent = template.std(0) < 0.1 * noise_levels - - template[:, is_silent] = 0 - - channel_norms = np.linalg.norm(template, axis=0) ** 2 - total_norm = np.linalg.norm(template) ** 2 - - idx = np.argsort(channel_norms)[::-1] - explained_norms = np.cumsum(channel_norms[idx] / total_norm) - channel = np.searchsorted(explained_norms, sparsify_threshold) - active_channels = np.sort(idx[:channel]) - template[:, idx[channel:]] = 0 - return template, active_channels - - @classmethod - def _regularize_template(cls, template, smoothing_factor=0.25): - nb_channels = template.shape[1] - nb_timesteps = template.shape[0] - xaxis = np.arange(nb_timesteps) - for i in range(nb_channels): - z = scipy.interpolate.UnivariateSpline(xaxis, template[:, i]) - z.set_smoothing_factor(smoothing_factor) - template[:, i] = z(xaxis) - return template - @classmethod def _prepare_templates(cls, d): - parameters = d - waveform_extractor = parameters["waveform_extractor"] - num_samples = parameters["num_samples"] - num_channels = parameters["num_channels"] - num_templates = parameters["num_templates"] - max_amplitude = parameters["max_amplitude"] - min_amplitude = parameters["min_amplitude"] - use_sparse_matrix_threshold = parameters["use_sparse_matrix_threshold"] + + waveform_extractor = d['waveform_extractor'] + num_samples = d['num_samples'] + num_channels = d['num_channels'] + num_templates = d['num_templates'] + use_sparse_matrix_threshold = d['use_sparse_matrix_threshold'] - parameters["norms"] = np.zeros(num_templates, dtype=np.float32) + d['norms'] = np.zeros(num_templates, dtype=np.float32) - all_units = list(parameters["waveform_extractor"].sorting.unit_ids) + all_units = list(d['waveform_extractor'].sorting.unit_ids) - templates = waveform_extractor.get_all_templates(mode="median").copy() + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d['sparse_kwargs']).mask + templates = waveform_extractor.get_all_templates(mode='median').copy() + d['sparsities'] = {} + for count, unit_id in enumerate(all_units): - if parameters["smoothing_factor"] > 0: - templates[count] = cls._regularize_template(templates[count], parameters["smoothing_factor"]) - templates[count], _ = cls._sparsify_template( - templates[count], parameters["sparsify_threshold"], parameters["noise_levels"] - ) - parameters["norms"][count] = np.linalg.norm(templates[count]) - templates[count] /= parameters["norms"][count] + d['sparsities'][count], = np.nonzero(sparsity[count]) + templates[count][sparsity[count] == False] = 0 + d['norms'][count] = np.linalg.norm(templates[count]) + templates[count] /= d['norms'][count] templates = templates.reshape(num_templates, -1) - nnz = np.sum(templates != 0) / (num_templates * num_samples * num_channels) + nnz = np.sum(templates != 0)/(num_templates * num_samples * num_channels) if nnz <= use_sparse_matrix_threshold: templates = scipy.sparse.csr_matrix(templates) - print(f"Templates are automatically sparsified (sparsity level is {nnz})") - parameters["is_dense"] = False - else: - parameters["is_dense"] = True - - parameters["templates"] = templates - - return parameters - - @classmethod - def _prepare_overlaps(cls, d): - templates = d["templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - is_dense = d["is_dense"] - - if not is_dense: - dense_templates = templates.toarray() + print(f'Templates are automatically sparsified (sparsity level is {nnz})') + d['is_dense'] = False else: - dense_templates = templates - - dense_templates = dense_templates.reshape(num_templates, num_samples, num_channels) - - size = 2 * num_samples - 1 - - all_delays = list(range(0, num_samples + 1)) - if d["progess_bar_steps"]: - all_delays = tqdm(all_delays, desc="[1] compute overlaps") - - overlaps = {} - - for delay in all_delays: - source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) - - overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) + d['is_dense'] = True - if delay < num_samples: - overlaps[size - delay] = overlaps[delay].T.tocsr() - - new_overlaps = [] - - for i in range(num_templates): - data = [overlaps[j][i, :].T for j in range(size)] - data = scipy.sparse.hstack(data) - new_overlaps += [data] - - d["overlaps"] = new_overlaps + d['templates'] = templates return d @@ -661,9 +591,9 @@ def _mcc_error(cls, bounds, good, bad): fp = np.sum((bounds[0] <= bad) & (bad <= bounds[1])) tp = np.sum((bounds[0] <= good) & (good <= bounds[1])) tn = np.sum((bad < bounds[0]) | (bad > bounds[1])) - denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + denom = (tp+fp)*(tp+fn)*(tn+fp)*(tn+fn) if denom > 0: - mcc = 1 - (tp * tn - fp * fn) / np.sqrt(denom) + mcc = 1 - (tp*tn - fp*fn)/np.sqrt(denom) else: mcc = 1 return mcc @@ -708,16 +638,6 @@ def _optimize_amplitudes(cls, noise_snippets, d): res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) parameters["amplitudes"][count] = res.x - # import pylab as plt - # plt.hist(good, 100, alpha=0.5) - # plt.hist(bad, 100, alpha=0.5) - # plt.hist(noise[count], 100, alpha=0.5) - # ymin, ymax = plt.ylim() - # plt.plot([res.x[0], res.x[0]], [ymin, ymax], 'k--') - # plt.plot([res.x[1], res.x[1]], [ymin, ymax], 'k--') - # plt.savefig('test_%d.png' %count) - # plt.close() - return d @classmethod @@ -727,7 +647,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): default_parameters.update(kwargs) # assert isinstance(d['waveform_extractor'], WaveformExtractor) - for v in ["sparsify_threshold", "use_sparse_matrix_threshold"]: assert (default_parameters[v] >= 0) and (default_parameters[v] <= 1), f"{v} should be in [0, 1]" @@ -817,31 +736,31 @@ def main_function(cls, traces, d): sym_patch = d["sym_patch"] peak_traces = traces[margin // 2 : -margin // 2, :] - peak_sample_ind, peak_chan_ind = DetectPeakByChannel.detect_peaks( + peak_sample_index, peak_chan_ind = DetectPeakByChannel.detect_peaks( peak_traces, peak_sign, abs_threholds, exclude_sweep_size ) if jitter > 0: - jittered_peaks = peak_sample_ind[:, np.newaxis] + np.arange(-jitter, jitter) + jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-jitter, jitter) jittered_channels = peak_chan_ind[:, np.newaxis] + np.zeros(2 * jitter) mask = (jittered_peaks > 0) & (jittered_peaks < len(peak_traces)) jittered_peaks = jittered_peaks[mask] jittered_channels = jittered_channels[mask] - peak_sample_ind, unique_idx = np.unique(jittered_peaks, return_index=True) + peak_sample_index, unique_idx = np.unique(jittered_peaks, return_index=True) peak_chan_ind = jittered_channels[unique_idx] else: - peak_sample_ind, unique_idx = np.unique(peak_sample_ind, return_index=True) + peak_sample_index, unique_idx = np.unique(peak_sample_index, return_index=True) peak_chan_ind = peak_chan_ind[unique_idx] - num_peaks = len(peak_sample_ind) + num_peaks = len(peak_sample_index) if sym_patch: - snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_ind] - peak_sample_ind += margin // 2 + snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_index] + peak_sample_index += margin // 2 else: - peak_sample_ind += margin // 2 + peak_sample_index += margin // 2 snippet_window = np.arange(-d["nbefore"], d["nafter"]) - snippets = traces[peak_sample_ind[:, np.newaxis] + snippet_window] + snippets = traces[peak_sample_index[:, np.newaxis] + snippet_window] if num_peaks > 0: snippets = snippets.reshape(num_peaks, -1) @@ -865,10 +784,10 @@ def main_function(cls, traces, d): best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) best_amplitude = scalar_products[best_cluster_ind, peak_index] - best_peak_sample_ind = peak_sample_ind[peak_index] + best_peak_sample_index = peak_sample_index[peak_index] best_peak_chan_ind = peak_chan_ind[peak_index] - peak_data = peak_sample_ind - peak_sample_ind[peak_index] + peak_data = peak_sample_index - peak_sample_index[peak_index] is_valid_nn = np.searchsorted(peak_data, [-neighbor_window, neighbor_window + 1]) idx_neighbor = peak_data[is_valid_nn[0] : is_valid_nn[1]] + neighbor_window @@ -880,7 +799,7 @@ def main_function(cls, traces, d): scalar_products[:, is_valid_nn[0] : is_valid_nn[1]] += to_add scalar_products[best_cluster_ind, is_valid_nn[0] : is_valid_nn[1]] = -np.inf - spikes["sample_index"][num_spikes] = best_peak_sample_ind + spikes["sample_index"][num_spikes] = best_peak_sample_index spikes["channel_index"][num_spikes] = best_peak_chan_ind spikes["cluster_index"][num_spikes] = best_cluster_ind spikes["amplitude"][num_spikes] = best_amplitude From 0f9fee6fe788a0cdc44c18d19fd8b0f11f10ff4f Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 17 Jul 2023 10:30:33 +0200 Subject: [PATCH 02/22] WIP --- .../sorters/internal/spyking_circus2.py | 59 ++++++++++--------- .../clustering/clustering_tools.py | 2 +- .../clustering/random_projections.py | 3 +- 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 24c4a7ccfc..18db5f37c8 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -3,7 +3,7 @@ import os import shutil import numpy as np -import os +import psutil from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs @@ -18,23 +18,24 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): - sorter_name = "spykingcircus2" + sorter_name = 'spykingcircus2' _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "local_radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True}, - "filtering": {"dtype": "float32"}, - "detection": {"peak_sign": "neg", "detect_threshold": 5}, - "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "localization": {}, - "clustering": {}, - "matching": {}, - "registration": {}, - "apply_preprocessing": True, - "shared_memory": False, - "job_kwargs": {}, + 'general' : {'ms_before' : 2, 'ms_after' : 2, 'local_radius_um' : 75}, + 'waveforms' : {'max_spikes_per_unit' : 200, 'overwrite' : True, 'sparse' : True, + 'method' : 'ptp', 'threshold' : 1}, + 'filtering' : {'dtype' : 'float32'}, + 'detection' : {'peak_sign': 'neg', 'detect_threshold': 5}, + 'selection' : {'n_peaks_per_channel' : 5000, 'min_n_peaks' : 20000}, + 'localization' : {}, + 'clustering': {}, + 'matching': {}, + 'apply_preprocessing': True, + 'shared_memory' : True, + 'job_kwargs' : {'n_jobs' : -1, 'chunk_memory' : "10M"} } + @classmethod def get_sorter_version(cls): return "2.0" @@ -63,8 +64,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - # if recording.is_filtered == True: - # print('Looks like the recording is already filtered, check preprocessing!') recording_f = bandpass_filter(recording, **filtering_params) recording_f = common_reference(recording_f) else: @@ -102,12 +101,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets - clustering_params = params["clustering"].copy() - clustering_params.update(params["waveforms"]) - clustering_params.update(params["general"]) - clustering_params.update(dict(shared_memory=params["shared_memory"])) - clustering_params["job_kwargs"] = job_kwargs - clustering_params["tmp_folder"] = sorter_output_folder / "clustering" + clustering_params = params['clustering'].copy() + clustering_params['waveforms_kwargs'] = params['waveforms'] + + for k in ['ms_before', 'ms_after']: + clustering_params['waveforms_kwargs'][k] = params['general'][k] + + clustering_params.update(dict(shared_memory=params['shared_memory'])) + clustering_params['job_kwargs'] = job_kwargs + clustering_params['tmp_folder'] = sorter_output_folder / "clustering" labels, peak_labels = find_cluster_from_peaks( recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params @@ -122,15 +124,18 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=clustering_folder) - ## We get the templates our of such a clustering - waveforms_params = params["waveforms"].copy() + ## We get the templates our of such a clustering + waveforms_params = params['waveforms'].copy() waveforms_params.update(job_kwargs) - if params["shared_memory"]: - mode = "memory" + for k in ['ms_before', 'ms_after']: + waveforms_params[k] = params['general'][k] + + if params['shared_memory']: + mode = 'memory' waveforms_folder = None else: - mode = "folder" + mode = 'folder' waveforms_folder = sorter_output_folder / "waveforms" we = extract_waveforms( diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 53833b01a2..6edf5af16b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -579,7 +579,7 @@ def remove_duplicates_via_matching( f.write(blanck) f.close() - recording = BinaryRecordingExtractor(tmp_filename, num_chan=num_chans, sampling_frequency=fs, dtype="float32") + recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") recording.annotate(is_filtered=True) margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 02247dd288..1450ba91db 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -238,7 +238,8 @@ def main_function(cls, recording, peaks, params): if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) else: - shutil.rmtree(tmp_folder / "waveforms") + if not params["shared_memory"]: + shutil.rmtree(tmp_folder / "waveforms") shutil.rmtree(tmp_folder / "sorting") if verbose: From 7a3d4c2181da06c4106d6c17a015839a0cc55f4f Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 17 Jul 2023 14:06:10 +0200 Subject: [PATCH 03/22] WIP --- .../sortingcomponents/matching/circus.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 8f08aac9c5..d86dac97e2 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -194,7 +194,6 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): """ _default_params = { - "sparsify_threshold": 1, "amplitudes": [0.6, 2], "omp_min_sps": 0.1, "waveform_extractor": None, @@ -219,6 +218,7 @@ def _prepare_templates(cls, d): else: sparsity = waveform_extractor.sparsity.mask + print(sparsity.mean()) templates = waveform_extractor.get_all_templates(mode='median').copy() d['sparsities'] = {} @@ -226,10 +226,10 @@ def _prepare_templates(cls, d): d['norms'] = np.zeros(num_templates, dtype=np.float32) for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - template = templates[count] + template = templates[count][:, sparsity[count]] d['sparsities'][count], = np.nonzero(sparsity[count]) d['norms'][count] = np.linalg.norm(template) - d['templates'][count] = template[:, d['sparsities'][count]]/d['norms'][count] + d['templates'][count] = template/d['norms'][count] return d @@ -269,8 +269,8 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d['ignored_ids'] = np.array(d['ignored_ids']) omp_min_sps = d['omp_min_sps'] - nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) - d['stop_criteria'] = omp_min_sps * np.sqrt(nb_active_channels * d['num_samples']) + #nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) + d['stop_criteria'] = omp_min_sps * np.sqrt(d['noise_levels'].sum() * d['num_samples']) return d @@ -307,7 +307,7 @@ def main_function(cls, traces, d): min_amplitude, max_amplitude = d['amplitudes'] sparsities = d['sparsities'] ignored_ids = d['ignored_ids'] - stop_criteria = d['stop_criteria'][:, np.newaxis] + stop_criteria = d['stop_criteria'] vicinity = d['vicinity'] if 'cached_fft_kernels' not in d: @@ -356,7 +356,7 @@ def main_function(cls, traces, d): spikes = np.empty(scalar_products.size, dtype=spike_dtype) idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) - M = np.zeros((100, 100), dtype=np.float32) + M = np.zeros((num_peaks, num_peaks), dtype=np.float32) all_selections = np.empty((2, scalar_products.size), dtype=np.int32) final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) @@ -647,7 +647,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): default_parameters.update(kwargs) # assert isinstance(d['waveform_extractor'], WaveformExtractor) - for v in ["sparsify_threshold", "use_sparse_matrix_threshold"]: + for v in ["use_sparse_matrix_threshold"]: assert (default_parameters[v] >= 0) and (default_parameters[v] <= 1), f"{v} should be in [0, 1]" default_parameters["num_channels"] = default_parameters["waveform_extractor"].recording.get_num_channels() From 892305bef89b97454fcda956f39b81e3b7673d55 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 24 Jul 2023 12:01:54 +0200 Subject: [PATCH 04/22] WIP --- src/spikeinterface/sortingcomponents/matching/circus.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index d86dac97e2..d3d2c39836 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -218,7 +218,6 @@ def _prepare_templates(cls, d): else: sparsity = waveform_extractor.sparsity.mask - print(sparsity.mean()) templates = waveform_extractor.get_all_templates(mode='median').copy() d['sparsities'] = {} @@ -542,7 +541,7 @@ class CircusPeeler(BaseTemplateMatchingEngine): 'use_sparse_matrix_threshold' : 0.25, 'progess_bar_steps' : False, 'waveform_extractor': None, - 'sparse_kwargs' : {'method' : 'threshold', 'threshold' : 0.5, 'peak_sign' : 'both'} + 'sparse_kwargs' : {'method' : 'ptp', 'threshold' : 1} } @classmethod From 1cb122c040b256bd0073e798e96880e19bff6d59 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 28 Aug 2023 13:35:59 +0200 Subject: [PATCH 05/22] WIP for circus2 --- .../sortingcomponents/clustering/clustering_tools.py | 1 + src/spikeinterface/sortingcomponents/matching/circus.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6edf5af16b..06e0b8ea96 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -581,6 +581,7 @@ def remove_duplicates_via_matching( recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") recording.annotate(is_filtered=True) + recording = recording.set_probe(waveform_extractor.recording.get_probe()) margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) half_marging = margin // 2 diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index d3d2c39836..ef823316a2 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -559,6 +559,8 @@ def _prepare_templates(cls, d): if not waveform_extractor.is_sparse(): sparsity = compute_sparsity(waveform_extractor, **d['sparse_kwargs']).mask + else: + sparsity = waveform_extractor.sparsity.mask templates = waveform_extractor.get_all_templates(mode='median').copy() d['sparsities'] = {} From ef204dd83e9f6fe627b849619932c44c331e2306 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 28 Aug 2023 13:58:00 +0200 Subject: [PATCH 06/22] WIP --- .../clustering/clustering_tools.py | 13 +- .../clustering/random_projections.py | 131 +++++++----------- 2 files changed, 58 insertions(+), 86 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 06e0b8ea96..f93142152f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -536,7 +536,6 @@ def remove_duplicates_via_matching( waveform_extractor, noise_levels, peak_labels, - sparsify_threshold=1, method_kwargs={}, job_kwargs={}, tmp_folder=None, @@ -552,6 +551,10 @@ def remove_duplicates_via_matching( from pathlib import Path job_kwargs = fix_job_kwargs(job_kwargs) + + if waveform_extractor.is_sparse(): + sparsity = waveform_extractor.sparsity.mask + templates = waveform_extractor.get_all_templates(mode="median").copy() nb_templates = len(templates) duration = waveform_extractor.nbefore + waveform_extractor.nafter @@ -559,9 +562,10 @@ def remove_duplicates_via_matching( fs = waveform_extractor.recording.get_sampling_frequency() num_chans = waveform_extractor.recording.get_num_channels() - for t in range(nb_templates): - is_silent = templates[t].ptp(0) < sparsify_threshold - templates[t, :, is_silent] = 0 + if waveform_extractor.is_sparse(): + for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): + templates[count][:, ~sparsity[count]] = 0 + zdata = templates.reshape(nb_templates, -1) @@ -598,7 +602,6 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], - "sparsify_threshold": sparsify_threshold, "omp_min_sps": 0.1, "templates": None, "overlaps": None, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 0803763573..5e14fa4736 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -41,7 +41,6 @@ class RandomProjectionClustering: "ms_before": 1.5, "ms_after": 1.5, "random_seed": 42, - "cleaning_method": "matching", "shared_memory": False, "min_values": {"ptp": 0, "energy": 0}, "tmp_folder": None, @@ -160,87 +159,57 @@ def main_function(cls, recording, peaks, params): spikes["segment_index"] = peaks[mask]["segment_index"] spikes["unit_index"] = peak_labels[mask] - cleaning_method = params["cleaning_method"] - if verbose: - print("We found %d raw clusters, starting to clean with %s..." % (len(labels), cleaning_method)) - - if cleaning_method == "cosine": - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - labels, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - folder=None, - dtype=recording.get_dtype(), - sparsity_mask=None, - copy=True, - **params["job_kwargs"], - ) - - labels, peak_labels = remove_duplicates( - wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"] - ) - - elif cleaning_method == "dip": - wfs_arrays = {} - for label in labels: - mask = label == peak_labels - wfs_arrays[label] = hdbscan_data[mask] - - labels, peak_labels = remove_duplicates_via_dip(wfs_arrays, peak_labels, **params["cleaning_kwargs"]) - - elif cleaning_method == "matching": - # create a tmp folder - if params["tmp_folder"] is None: - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = get_global_tmp_folder() / name - else: - tmp_folder = Path(params["tmp_folder"]) - - if params["shared_memory"]: - waveform_folder = None - mode = "memory" - else: - waveform_folder = tmp_folder / "waveforms" - mode = "folder" - - sorting_folder = tmp_folder / "sorting" - sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) - sorting = sorting.save(folder=sorting_folder) - we = extract_waveforms( - recording, - sorting, - waveform_folder, - ms_before=params["ms_before"], - ms_after=params["ms_after"], - **params["job_kwargs"], - return_scaled=False, - mode=mode, - ) - - cleaning_matching_params = params["job_kwargs"].copy() - cleaning_matching_params["chunk_duration"] = "100ms" - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["verbose"] = False - cleaning_matching_params["progress_bar"] = False - - cleaning_params = params["cleaning_kwargs"].copy() - cleaning_params["tmp_folder"] = tmp_folder - - labels, peak_labels = remove_duplicates_via_matching( - we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params - ) - - if params["tmp_folder"] is None: - shutil.rmtree(tmp_folder) - else: - if not params["shared_memory"]: - shutil.rmtree(tmp_folder / "waveforms") - shutil.rmtree(tmp_folder / "sorting") + print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) + + + # create a tmp folder + if params["tmp_folder"] is None: + name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + tmp_folder = get_global_tmp_folder() / name + else: + tmp_folder = Path(params["tmp_folder"]) + + if params["shared_memory"]: + waveform_folder = None + mode = "memory" + else: + waveform_folder = tmp_folder / "waveforms" + mode = "folder" + + sorting_folder = tmp_folder / "sorting" + sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) + sorting = sorting.save(folder=sorting_folder) + we = extract_waveforms( + recording, + sorting, + waveform_folder, + ms_before=params["ms_before"], + ms_after=params["ms_after"], + **params["job_kwargs"], + return_scaled=False, + mode=mode, + ) + + cleaning_matching_params = params["job_kwargs"].copy() + cleaning_matching_params["chunk_duration"] = "100ms" + cleaning_matching_params["n_jobs"] = 1 + cleaning_matching_params["verbose"] = False + cleaning_matching_params["progress_bar"] = False + + cleaning_params = params["cleaning_kwargs"].copy() + cleaning_params["tmp_folder"] = tmp_folder + + labels, peak_labels = remove_duplicates_via_matching( + we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + ) + + if params["tmp_folder"] is None: + shutil.rmtree(tmp_folder) + else: + if not params["shared_memory"]: + shutil.rmtree(tmp_folder / "waveforms") + shutil.rmtree(tmp_folder / "sorting") if verbose: print("We kept %d non-duplicated clusters..." % len(labels)) From 242799ff582d886ad8438b9344eea594e07324af Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 28 Aug 2023 14:02:05 +0200 Subject: [PATCH 07/22] Docs --- .../sortingcomponents/matching/circus.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index ef823316a2..50058ab39e 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -5,7 +5,6 @@ import scipy.spatial -from tqdm import tqdm import scipy try: @@ -190,6 +189,9 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): computed random_chunk_kwargs: dict Parameters for computing noise levels, if not provided (sub optimal) + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. ----- """ @@ -522,8 +524,9 @@ class CircusPeeler(BaseTemplateMatchingEngine): use_sparse_matrix_threshold: float If density of the templates is below a given threshold, sparse matrix are used (memory efficient) - progress_bar_steps: bool - In order to display or not steps from the algorithm + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. ----- @@ -539,7 +542,6 @@ class CircusPeeler(BaseTemplateMatchingEngine): 'max_amplitude' : 1.5, 'min_amplitude' : 0.5, 'use_sparse_matrix_threshold' : 0.25, - 'progess_bar_steps' : False, 'waveform_extractor': None, 'sparse_kwargs' : {'method' : 'ptp', 'threshold' : 1} } @@ -618,8 +620,6 @@ def _optimize_amplitudes(cls, noise_snippets, d): alpha = 0.5 norms = parameters["norms"] all_units = list(waveform_extractor.sorting.unit_ids) - if parameters["progess_bar_steps"]: - all_units = tqdm(all_units, desc="[2] compute amplitudes") parameters["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) noise = templates.dot(noise_snippets) / norms[:, np.newaxis] From 5566c917ddbd32feda022e4293ba0bc93bdd3139 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 29 Aug 2023 08:46:28 +0200 Subject: [PATCH 08/22] Fix for circus --- .../sortingcomponents/matching/circus.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 50058ab39e..f79cf60a31 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -357,7 +357,7 @@ def main_function(cls, traces, d): spikes = np.empty(scalar_products.size, dtype=spike_dtype) idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) - M = np.zeros((num_peaks, num_peaks), dtype=np.float32) + M = np.zeros((100, 100), dtype=np.float32) all_selections = np.empty((2, scalar_products.size), dtype=np.int32) final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) @@ -570,7 +570,7 @@ def _prepare_templates(cls, d): for count, unit_id in enumerate(all_units): d['sparsities'][count], = np.nonzero(sparsity[count]) - templates[count][sparsity[count] == False] = 0 + templates[count][:, ~sparsity[count]] = 0 d['norms'][count] = np.linalg.norm(templates[count]) templates[count] /= d['norms'][count] @@ -666,7 +666,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs): ) default_parameters = cls._prepare_templates(default_parameters) - default_parameters = cls._prepare_overlaps(default_parameters) + + templates = default_parameters['templates'].reshape(len(default_parameters['templates']), + default_parameters['num_samples'], + default_parameters['num_channels']) + + default_parameters['overlaps'] = compute_overlaps(templates, + default_parameters['num_samples'], + default_parameters['num_channels'], + default_parameters['sparsities']) default_parameters["exclude_sweep_size"] = int( default_parameters["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 From 75c97937c1f5f66714076dba237574eddbb9782c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 29 Aug 2023 09:12:16 +0200 Subject: [PATCH 09/22] WIP --- src/spikeinterface/sortingcomponents/matching/circus.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index f79cf60a31..baf7494002 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -432,13 +432,14 @@ def main_function(cls, traces, d): selection = all_selections[:, :num_selection] res_sps = full_sps[selection[0], selection[1]] - if vicinity == 0: + if True: #vicinity == 0: all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) all_amplitudes /= norms[selection[0]] else: + # This is not working, need to figure out why is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) - all_amplitudes = np.append(all_amplitudes, np.float32(0)) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) L = M[is_in_vicinity, :][:, is_in_vicinity] all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) From d7e9ac1c803121b7e0fb0d8c4af539340fb82bbe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 07:14:41 +0000 Subject: [PATCH 10/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 56 ++-- .../clustering/clustering_tools.py | 1 - .../clustering/random_projections.py | 1 - .../sortingcomponents/matching/circus.py | 286 +++++++++--------- 4 files changed, 166 insertions(+), 178 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 6635bbfca1..4ccaef8e29 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -18,24 +18,22 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): - sorter_name = 'spykingcircus2' + sorter_name = "spykingcircus2" _default_params = { - 'general' : {'ms_before' : 2, 'ms_after' : 2, 'radius_um' : 75}, - 'waveforms' : {'max_spikes_per_unit' : 200, 'overwrite' : True, 'sparse' : True, - 'method' : 'ptp', 'threshold' : 1}, - 'filtering' : {'dtype' : 'float32'}, - 'detection' : {'peak_sign': 'neg', 'detect_threshold': 5}, - 'selection' : {'n_peaks_per_channel' : 5000, 'min_n_peaks' : 20000}, - 'localization' : {}, - 'clustering': {}, - 'matching': {}, - 'apply_preprocessing': True, - 'shared_memory' : True, - 'job_kwargs' : {'n_jobs' : -1, 'chunk_memory' : "10M"} + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, + "filtering": {"dtype": "float32"}, + "detection": {"peak_sign": "neg", "detect_threshold": 5}, + "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, + "localization": {}, + "clustering": {}, + "matching": {}, + "apply_preprocessing": True, + "shared_memory": True, + "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M"}, } - @classmethod def get_sorter_version(cls): return "2.0" @@ -101,15 +99,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets - clustering_params = params['clustering'].copy() - clustering_params['waveforms_kwargs'] = params['waveforms'] - - for k in ['ms_before', 'ms_after']: - clustering_params['waveforms_kwargs'][k] = params['general'][k] + clustering_params = params["clustering"].copy() + clustering_params["waveforms_kwargs"] = params["waveforms"] + + for k in ["ms_before", "ms_after"]: + clustering_params["waveforms_kwargs"][k] = params["general"][k] - clustering_params.update(dict(shared_memory=params['shared_memory'])) - clustering_params['job_kwargs'] = job_kwargs - clustering_params['tmp_folder'] = sorter_output_folder / "clustering" + clustering_params.update(dict(shared_memory=params["shared_memory"])) + clustering_params["job_kwargs"] = job_kwargs + clustering_params["tmp_folder"] = sorter_output_folder / "clustering" labels, peak_labels = find_cluster_from_peaks( recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params @@ -124,18 +122,18 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=clustering_folder) - ## We get the templates our of such a clustering - waveforms_params = params['waveforms'].copy() + ## We get the templates our of such a clustering + waveforms_params = params["waveforms"].copy() waveforms_params.update(job_kwargs) - for k in ['ms_before', 'ms_after']: - waveforms_params[k] = params['general'][k] + for k in ["ms_before", "ms_after"]: + waveforms_params[k] = params["general"][k] - if params['shared_memory']: - mode = 'memory' + if params["shared_memory"]: + mode = "memory" waveforms_folder = None else: - mode = 'folder' + mode = "folder" waveforms_folder = sorter_output_folder / "waveforms" we = extract_waveforms( diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index f93142152f..b11af55d35 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -565,7 +565,6 @@ def remove_duplicates_via_matching( if waveform_extractor.is_sparse(): for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): templates[count][:, ~sparsity[count]] = 0 - zdata = templates.reshape(nb_templates, -1) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 5e14fa4736..ac564bda9a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -162,7 +162,6 @@ def main_function(cls, recording, peaks, params): if verbose: print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) - # create a tmp folder if params["tmp_folder"] is None: name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index baf7494002..b0f132e94d 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -130,8 +130,8 @@ def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): return ret -def compute_overlaps(templates, num_samples, num_channels, sparsities): +def compute_overlaps(templates, num_samples, num_channels, sparsities): num_templates = len(templates) dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) @@ -140,13 +140,13 @@ def compute_overlaps(templates, num_samples, num_channels, sparsities): size = 2 * num_samples - 1 - all_delays = list(range(0, num_samples+1)) + all_delays = list(range(0, num_samples + 1)) overlaps = {} - + for delay in all_delays: source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples-delay:, :].reshape(num_templates, -1) + target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) @@ -161,7 +161,7 @@ def compute_overlaps(templates, num_samples, num_channels, sparsities): new_overlaps += [data] return new_overlaps - + class CircusOMPPeeler(BaseTemplateMatchingEngine): """ @@ -204,77 +204,74 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): "norms": None, "random_chunk_kwargs": {}, "noise_levels": None, - 'sparse_kwargs' : {'method' : 'ptp', 'threshold' : 1}, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], - "vicinity" : 0 + "vicinity": 0, } @classmethod def _prepare_templates(cls, d): - - waveform_extractor = d['waveform_extractor'] - num_templates = len(d['waveform_extractor'].sorting.unit_ids) + waveform_extractor = d["waveform_extractor"] + num_templates = len(d["waveform_extractor"].sorting.unit_ids) if not waveform_extractor.is_sparse(): - sparsity = compute_sparsity(waveform_extractor, **d['sparse_kwargs']).mask + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask else: sparsity = waveform_extractor.sparsity.mask - - templates = waveform_extractor.get_all_templates(mode='median').copy() - d['sparsities'] = {} - d['templates'] = {} - d['norms'] = np.zeros(num_templates, dtype=np.float32) + templates = waveform_extractor.get_all_templates(mode="median").copy() + + d["sparsities"] = {} + d["templates"] = {} + d["norms"] = np.zeros(num_templates, dtype=np.float32) for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): template = templates[count][:, sparsity[count]] - d['sparsities'][count], = np.nonzero(sparsity[count]) - d['norms'][count] = np.linalg.norm(template) - d['templates'][count] = template/d['norms'][count] + (d["sparsities"][count],) = np.nonzero(sparsity[count]) + d["norms"][count] = np.linalg.norm(template) + d["templates"][count] = template / d["norms"][count] return d @classmethod def initialize_and_check_kwargs(cls, recording, kwargs): - d = cls._default_params.copy() d.update(kwargs) - #assert isinstance(d['waveform_extractor'], WaveformExtractor) + # assert isinstance(d['waveform_extractor'], WaveformExtractor) + + for v in ["omp_min_sps"]: + assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" - for v in ['omp_min_sps']: - assert (d[v] >= 0) and (d[v] <= 1), f'{v} should be in [0, 1]' - - d['num_channels'] = d['waveform_extractor'].recording.get_num_channels() - d['num_samples'] = d['waveform_extractor'].nsamples - d['nbefore'] = d['waveform_extractor'].nbefore - d['nafter'] = d['waveform_extractor'].nafter - d['sampling_frequency'] = d['waveform_extractor'].recording.get_sampling_frequency() - d['vicinity'] *= d['num_samples'] + d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() + d["num_samples"] = d["waveform_extractor"].nsamples + d["nbefore"] = d["waveform_extractor"].nbefore + d["nafter"] = d["waveform_extractor"].nafter + d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + d["vicinity"] *= d["num_samples"] - if d['noise_levels'] is None: - print('CircusOMPPeeler : noise should be computed outside') - d['noise_levels'] = get_noise_levels(recording, **d['random_chunk_kwargs'], return_scaled=False) + if d["noise_levels"] is None: + print("CircusOMPPeeler : noise should be computed outside") + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - if d['templates'] is None: + if d["templates"] is None: d = cls._prepare_templates(d) else: - for key in ['norms', 'sparsities']: - assert d[key] is not None, "If templates are provided, %d should also be there" %key - - d['num_templates'] = len(d['templates']) + for key in ["norms", "sparsities"]: + assert d[key] is not None, "If templates are provided, %d should also be there" % key - if d['overlaps'] is None: - d['overlaps'] = compute_overlaps(d['templates'], d['num_samples'], d['num_channels'], d['sparsities']) + d["num_templates"] = len(d["templates"]) - d['ignored_ids'] = np.array(d['ignored_ids']) + if d["overlaps"] is None: + d["overlaps"] = compute_overlaps(d["templates"], d["num_samples"], d["num_channels"], d["sparsities"]) - omp_min_sps = d['omp_min_sps'] - #nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) - d['stop_criteria'] = omp_min_sps * np.sqrt(d['noise_levels'].sum() * d['num_samples']) + d["ignored_ids"] = np.array(d["ignored_ids"]) - return d + omp_min_sps = d["omp_min_sps"] + # nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) + d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + return d @classmethod def serialize_method_kwargs(cls, kwargs): @@ -294,27 +291,27 @@ def get_margin(cls, recording, kwargs): @classmethod def main_function(cls, traces, d): - templates = d['templates'] - num_templates = d['num_templates'] - num_channels = d['num_channels'] - num_samples = d['num_samples'] - overlaps = d['overlaps'] - norms = d['norms'] - nbefore = d['nbefore'] - nafter = d['nafter'] + templates = d["templates"] + num_templates = d["num_templates"] + num_channels = d["num_channels"] + num_samples = d["num_samples"] + overlaps = d["overlaps"] + norms = d["norms"] + nbefore = d["nbefore"] + nafter = d["nafter"] omp_tol = np.finfo(np.float32).eps - num_samples = d['nafter'] + d['nbefore'] + num_samples = d["nafter"] + d["nbefore"] neighbor_window = num_samples - 1 - min_amplitude, max_amplitude = d['amplitudes'] - sparsities = d['sparsities'] - ignored_ids = d['ignored_ids'] - stop_criteria = d['stop_criteria'] - vicinity = d['vicinity'] + min_amplitude, max_amplitude = d["amplitudes"] + sparsities = d["sparsities"] + ignored_ids = d["ignored_ids"] + stop_criteria = d["stop_criteria"] + vicinity = d["vicinity"] - if 'cached_fft_kernels' not in d: - d['cached_fft_kernels'] = {'fshape' : 0} + if "cached_fft_kernels" not in d: + d["cached_fft_kernels"] = {"fshape": 0} - cached_fft_kernels = d['cached_fft_kernels'] + cached_fft_kernels = d["cached_fft_kernels"] num_timesteps = len(traces) @@ -326,24 +323,22 @@ def main_function(cls, traces, d): dummy_traces = np.empty((num_channels, num_timesteps), dtype=np.float32) fshape, axes = get_scipy_shape(dummy_filter, traces, axes=1) - fft_cache = {'full' : sp_fft.rfftn(traces, fshape, axes=axes)} + fft_cache = {"full": sp_fft.rfftn(traces, fshape, axes=axes)} scalar_products = np.empty((num_templates, num_peaks), dtype=np.float32) - flagged_chunk = cached_fft_kernels['fshape'] != fshape[0] + flagged_chunk = cached_fft_kernels["fshape"] != fshape[0] for i in range(num_templates): - if i not in ignored_ids: - if i not in cached_fft_kernels or flagged_chunk: kernel_filter = np.ascontiguousarray(templates[i][::-1].T) - cached_fft_kernels.update({i : sp_fft.rfftn(kernel_filter, fshape, axes=axes)}) - cached_fft_kernels['fshape'] = fshape[0] + cached_fft_kernels.update({i: sp_fft.rfftn(kernel_filter, fshape, axes=axes)}) + cached_fft_kernels["fshape"] = fshape[0] - fft_cache.update({'mask' : sparsities[i], 'template' : cached_fft_kernels[i]}) + fft_cache.update({"mask": sparsities[i], "template": cached_fft_kernels[i]}) - convolution = fftconvolve_with_cache(dummy_filter, dummy_traces, fft_cache, axes=1, mode='valid') + convolution = fftconvolve_with_cache(dummy_filter, dummy_traces, fft_cache, axes=1, mode="valid") if len(convolution) > 0: scalar_products[i] = convolution.sum(0) else: @@ -368,17 +363,15 @@ def main_function(cls, traces, d): neighbors = {} cached_overlaps = {} - is_valid = (scalar_products > stop_criteria) + is_valid = scalar_products > stop_criteria all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) while np.any(is_valid): - best_amplitude_ind = scalar_products[is_valid].argmax() best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) - - if num_selection > 0: + if num_selection > 0: delta_t = selection[1] - peak_index idx = np.where((delta_t < neighbor_window) & (delta_t > -num_samples))[0] myline = num_samples + delta_t[idx] @@ -387,17 +380,21 @@ def main_function(cls, traces, d): cached_overlaps[best_cluster_ind] = overlaps[best_cluster_ind].toarray() if num_selection == M.shape[0]: - Z = np.zeros((2*num_selection, 2*num_selection), dtype=np.float32) + Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) Z[:num_selection, :num_selection] = M M = Z M[num_selection, idx] = cached_overlaps[best_cluster_ind][selection[0, idx], myline] if vicinity == 0: - scipy.linalg.solve_triangular(M[:num_selection, :num_selection], M[num_selection, :num_selection], trans=0, - lower=1, - overwrite_b=True, - check_finite=False) + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) v = nrm2(M[num_selection, :num_selection]) ** 2 Lkk = 1 - v @@ -408,13 +405,11 @@ def main_function(cls, traces, d): is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] if len(is_in_vicinity) > 0: - L = M[is_in_vicinity, :][:, is_in_vicinity] - M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular(L, M[num_selection, is_in_vicinity], trans=0, - lower=1, - overwrite_b=True, - check_finite=False) + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False + ) v = nrm2(M[num_selection, is_in_vicinity]) ** 2 Lkk = 1 - v @@ -432,55 +427,52 @@ def main_function(cls, traces, d): selection = all_selections[:, :num_selection] res_sps = full_sps[selection[0], selection[1]] - if True: #vicinity == 0: - all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, - lower=True, overwrite_b=False) + if True: # vicinity == 0: + all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) all_amplitudes /= norms[selection[0]] else: # This is not working, need to figure out why is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) all_amplitudes = np.append(all_amplitudes, np.float32(1)) L = M[is_in_vicinity, :][:, is_in_vicinity] - all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], - lower=True, overwrite_b=False) + all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] - diff_amplitudes = (all_amplitudes - final_amplitudes[selection[0], selection[1]]) + diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] final_amplitudes[selection[0], selection[1]] = all_amplitudes for i in modified: - tmp_best, tmp_peak = selection[:, i] - diff_amp = diff_amplitudes[i]*norms[tmp_best] - + diff_amp = diff_amplitudes[i] * norms[tmp_best] + if not tmp_best in cached_overlaps: cached_overlaps[tmp_best] = overlaps[tmp_best].toarray() if not tmp_peak in neighbors.keys(): idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)] tdx = [num_samples + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak] - neighbors[tmp_peak] = {'idx' : idx, 'tdx' : tdx} + neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} - idx = neighbors[tmp_peak]['idx'] - tdx = neighbors[tmp_peak]['tdx'] + idx = neighbors[tmp_peak]["idx"] + tdx = neighbors[tmp_peak]["tdx"] - to_add = diff_amp * cached_overlaps[tmp_best][:, tdx[0]:tdx[1]] - scalar_products[:, idx[0]:idx[1]] -= to_add + to_add = diff_amp * cached_overlaps[tmp_best][:, tdx[0] : tdx[1]] + scalar_products[:, idx[0] : idx[1]] -= to_add - is_valid = (scalar_products > stop_criteria) + is_valid = scalar_products > stop_criteria - is_valid = (final_amplitudes > min_amplitude)*(final_amplitudes < max_amplitude) + is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) valid_indices = np.where(is_valid) num_spikes = len(valid_indices[0]) - spikes['sample_index'][:num_spikes] = valid_indices[1] + d['nbefore'] - spikes['channel_index'][:num_spikes] = 0 - spikes['cluster_index'][:num_spikes] = valid_indices[0] - spikes['amplitude'][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] - + spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] + spikes["channel_index"][:num_spikes] = 0 + spikes["cluster_index"][:num_spikes] = valid_indices[0] + spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] + spikes = spikes[:num_spikes] - order = np.argsort(spikes['sample_index']) + order = np.argsort(spikes["sample_index"]) spikes = spikes[order] return spikes @@ -534,58 +526,56 @@ class CircusPeeler(BaseTemplateMatchingEngine): """ _default_params = { - 'peak_sign': 'neg', - 'exclude_sweep_ms': 0.1, - 'jitter_ms' : 0.1, - 'detect_threshold': 5, - 'noise_levels': None, - 'random_chunk_kwargs': {}, - 'max_amplitude' : 1.5, - 'min_amplitude' : 0.5, - 'use_sparse_matrix_threshold' : 0.25, - 'waveform_extractor': None, - 'sparse_kwargs' : {'method' : 'ptp', 'threshold' : 1} + "peak_sign": "neg", + "exclude_sweep_ms": 0.1, + "jitter_ms": 0.1, + "detect_threshold": 5, + "noise_levels": None, + "random_chunk_kwargs": {}, + "max_amplitude": 1.5, + "min_amplitude": 0.5, + "use_sparse_matrix_threshold": 0.25, + "waveform_extractor": None, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, } @classmethod def _prepare_templates(cls, d): - - waveform_extractor = d['waveform_extractor'] - num_samples = d['num_samples'] - num_channels = d['num_channels'] - num_templates = d['num_templates'] - use_sparse_matrix_threshold = d['use_sparse_matrix_threshold'] + waveform_extractor = d["waveform_extractor"] + num_samples = d["num_samples"] + num_channels = d["num_channels"] + num_templates = d["num_templates"] + use_sparse_matrix_threshold = d["use_sparse_matrix_threshold"] - d['norms'] = np.zeros(num_templates, dtype=np.float32) + d["norms"] = np.zeros(num_templates, dtype=np.float32) - all_units = list(d['waveform_extractor'].sorting.unit_ids) + all_units = list(d["waveform_extractor"].sorting.unit_ids) if not waveform_extractor.is_sparse(): - sparsity = compute_sparsity(waveform_extractor, **d['sparse_kwargs']).mask + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask else: sparsity = waveform_extractor.sparsity.mask - templates = waveform_extractor.get_all_templates(mode='median').copy() - d['sparsities'] = {} - - for count, unit_id in enumerate(all_units): + templates = waveform_extractor.get_all_templates(mode="median").copy() + d["sparsities"] = {} - d['sparsities'][count], = np.nonzero(sparsity[count]) + for count, unit_id in enumerate(all_units): + (d["sparsities"][count],) = np.nonzero(sparsity[count]) templates[count][:, ~sparsity[count]] = 0 - d['norms'][count] = np.linalg.norm(templates[count]) - templates[count] /= d['norms'][count] + d["norms"][count] = np.linalg.norm(templates[count]) + templates[count] /= d["norms"][count] templates = templates.reshape(num_templates, -1) - nnz = np.sum(templates != 0)/(num_templates * num_samples * num_channels) + nnz = np.sum(templates != 0) / (num_templates * num_samples * num_channels) if nnz <= use_sparse_matrix_threshold: templates = scipy.sparse.csr_matrix(templates) - print(f'Templates are automatically sparsified (sparsity level is {nnz})') - d['is_dense'] = False + print(f"Templates are automatically sparsified (sparsity level is {nnz})") + d["is_dense"] = False else: - d['is_dense'] = True + d["is_dense"] = True - d['templates'] = templates + d["templates"] = templates return d @@ -595,9 +585,9 @@ def _mcc_error(cls, bounds, good, bad): fp = np.sum((bounds[0] <= bad) & (bad <= bounds[1])) tp = np.sum((bounds[0] <= good) & (good <= bounds[1])) tn = np.sum((bad < bounds[0]) | (bad > bounds[1])) - denom = (tp+fp)*(tp+fn)*(tn+fp)*(tn+fn) + denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) if denom > 0: - mcc = 1 - (tp*tn - fp*fn)/np.sqrt(denom) + mcc = 1 - (tp * tn - fp * fn) / np.sqrt(denom) else: mcc = 1 return mcc @@ -668,14 +658,16 @@ def initialize_and_check_kwargs(cls, recording, kwargs): default_parameters = cls._prepare_templates(default_parameters) - templates = default_parameters['templates'].reshape(len(default_parameters['templates']), - default_parameters['num_samples'], - default_parameters['num_channels']) + templates = default_parameters["templates"].reshape( + len(default_parameters["templates"]), default_parameters["num_samples"], default_parameters["num_channels"] + ) - default_parameters['overlaps'] = compute_overlaps(templates, - default_parameters['num_samples'], - default_parameters['num_channels'], - default_parameters['sparsities']) + default_parameters["overlaps"] = compute_overlaps( + templates, + default_parameters["num_samples"], + default_parameters["num_channels"], + default_parameters["sparsities"], + ) default_parameters["exclude_sweep_size"] = int( default_parameters["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 From 14c8f58571fefc60eaa544da476c0210d45d2b92 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 29 Aug 2023 11:09:02 +0200 Subject: [PATCH 11/22] useless dependency --- src/spikeinterface/sorters/internal/spyking_circus2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4ccaef8e29..ec2a74b6bb 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -3,7 +3,6 @@ import os import shutil import numpy as np -import psutil from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs From e455da3f46cc5529986f60c56cb7868391f12af5 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 29 Aug 2023 13:51:38 +0200 Subject: [PATCH 12/22] Fix for classical circus with sparsity --- .../sortingcomponents/matching/circus.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index b0f132e94d..cdacfe1304 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -136,6 +136,7 @@ def compute_overlaps(templates, num_samples, num_channels, sparsities): dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) for i in range(num_templates): + print(templates[i].shape, len(sparsities[i])) dense_templates[i, :, sparsities[i]] = templates[i].T size = 2 * num_samples - 1 @@ -558,12 +559,14 @@ def _prepare_templates(cls, d): templates = waveform_extractor.get_all_templates(mode="median").copy() d["sparsities"] = {} + d["circus_templates"] = {} for count, unit_id in enumerate(all_units): (d["sparsities"][count],) = np.nonzero(sparsity[count]) templates[count][:, ~sparsity[count]] = 0 d["norms"][count] = np.linalg.norm(templates[count]) templates[count] /= d["norms"][count] + d['circus_templates'][count] = templates[count][:, sparsity[count]] templates = templates.reshape(num_templates, -1) @@ -617,7 +620,7 @@ def _optimize_amplitudes(cls, noise_snippets, d): all_amps = {} for count, unit_id in enumerate(all_units): - waveform = waveform_extractor.get_waveforms(unit_id) + waveform = waveform_extractor.get_waveforms(unit_id, force_dense=True) snippets = waveform.reshape(waveform.shape[0], -1).T amps = templates.dot(snippets) / norms[:, np.newaxis] good = amps[count, :].flatten() @@ -658,12 +661,8 @@ def initialize_and_check_kwargs(cls, recording, kwargs): default_parameters = cls._prepare_templates(default_parameters) - templates = default_parameters["templates"].reshape( - len(default_parameters["templates"]), default_parameters["num_samples"], default_parameters["num_channels"] - ) - default_parameters["overlaps"] = compute_overlaps( - templates, + default_parameters['circus_templates'], default_parameters["num_samples"], default_parameters["num_channels"], default_parameters["sparsities"], From 2f84c6b632cd17391ba1eff0b89578b87f2fb892 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 11:51:59 +0000 Subject: [PATCH 13/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/matching/circus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index cdacfe1304..e92e7929f6 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -566,7 +566,7 @@ def _prepare_templates(cls, d): templates[count][:, ~sparsity[count]] = 0 d["norms"][count] = np.linalg.norm(templates[count]) templates[count] /= d["norms"][count] - d['circus_templates'][count] = templates[count][:, sparsity[count]] + d["circus_templates"][count] = templates[count][:, sparsity[count]] templates = templates.reshape(num_templates, -1) @@ -662,7 +662,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): default_parameters = cls._prepare_templates(default_parameters) default_parameters["overlaps"] = compute_overlaps( - default_parameters['circus_templates'], + default_parameters["circus_templates"], default_parameters["num_samples"], default_parameters["num_channels"], default_parameters["sparsities"], From 3d849fb91680f05c27c52dc240f61e65490c4a16 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 29 Aug 2023 13:52:34 +0200 Subject: [PATCH 14/22] Fix for classical circus with sparsity --- src/spikeinterface/sortingcomponents/matching/circus.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index cdacfe1304..06cd99d92a 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -136,7 +136,6 @@ def compute_overlaps(templates, num_samples, num_channels, sparsities): dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) for i in range(num_templates): - print(templates[i].shape, len(sparsities[i])) dense_templates[i, :, sparsities[i]] = templates[i].T size = 2 * num_samples - 1 From 7dcfdb0b325ffefb980c54ac5070339a490f8b49 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 29 Aug 2023 14:56:23 +0200 Subject: [PATCH 15/22] Fixing slow tests with SC2 --- src/spikeinterface/sorters/internal/spyking_circus2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index ec2a74b6bb..628ea991c1 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -30,7 +30,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "matching": {}, "apply_preprocessing": True, "shared_memory": True, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M"}, + "job_kwargs": {"n_jobs": -1}, } @classmethod @@ -145,6 +145,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_params.update({"noise_levels": noise_levels}) matching_job_params = job_kwargs.copy() + if 'chunk_memory' in matching_job_params: + matching_job_params.pop('chunk_memory') + matching_job_params["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( From 9f196b58acf4a5d2cc1ebc45a0ee969c03451d83 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 12:58:57 +0000 Subject: [PATCH 16/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 628ea991c1..8a7b353bd1 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -145,8 +145,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_params.update({"noise_levels": noise_levels}) matching_job_params = job_kwargs.copy() - if 'chunk_memory' in matching_job_params: - matching_job_params.pop('chunk_memory') + if "chunk_memory" in matching_job_params: + matching_job_params.pop("chunk_memory") matching_job_params["chunk_duration"] = "100ms" From 1c7c8020147e24997e3c34e374c76df8a72bc684 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 29 Aug 2023 15:25:58 +0200 Subject: [PATCH 17/22] WIP for cleaning --- .../sortingcomponents/clustering/random_projections.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index ac564bda9a..d9a317ca06 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -191,6 +191,8 @@ def main_function(cls, recording, peaks, params): ) cleaning_matching_params = params["job_kwargs"].copy() + if 'chunk_memory' in cleaning_matching_params: + cleaning_matching_params.pop('chunk_memory') cleaning_matching_params["chunk_duration"] = "100ms" cleaning_matching_params["n_jobs"] = 1 cleaning_matching_params["verbose"] = False From af4f1877aa800ff0277bd40a2aa83fc408b1ef08 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 13:31:36 +0000 Subject: [PATCH 18/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/random_projections.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index d9a317ca06..d82f9a7808 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -191,8 +191,8 @@ def main_function(cls, recording, peaks, params): ) cleaning_matching_params = params["job_kwargs"].copy() - if 'chunk_memory' in cleaning_matching_params: - cleaning_matching_params.pop('chunk_memory') + if "chunk_memory" in cleaning_matching_params: + cleaning_matching_params.pop("chunk_memory") cleaning_matching_params["chunk_duration"] = "100ms" cleaning_matching_params["n_jobs"] = 1 cleaning_matching_params["verbose"] = False From 8c2af8fcfa4c0ab4aa058e4778545b4cee64fa08 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 29 Aug 2023 18:09:23 +0200 Subject: [PATCH 19/22] WIP --- .../benchmark/benchmark_matching.py | 51 +++++++++++-------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 07c7db155c..8ce8efe25f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -600,29 +600,38 @@ def plot_comparison_matching( else: ax = axs[j] comp1, comp2 = comp_per_method[method1], comp_per_method[method2] - for performance, color in zip(performance_names, colors): - perf1 = comp1.get_performance()[performance] - perf2 = comp2.get_performance()[performance] - ax.plot(perf2, perf1, ".", label=performance, color=color) - ax.plot([0, 1], [0, 1], "k--", alpha=0.5) - ax.set_ylim(ylim) - ax.set_xlim(ylim) - ax.spines[["right", "top"]].set_visible(False) - ax.set_aspect("equal") - - if j == 0: - ax.set_ylabel(f"{method1}") - else: - ax.set_yticks([]) - if i == num_methods - 1: - ax.set_xlabel(f"{method2}") + if i <= j: + for performance, color in zip(performance_names, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.plot(perf2, perf1, ".", label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + if j == i: + ax.set_ylabel(f"{method1}") + else: + ax.set_yticks([]) + if i == j: + ax.set_xlabel(f"{method2}") + else: + ax.set_xticks([]) + if i == num_methods - 1 and j == num_methods - 1: + patches = [] + for color, name in zip(colors, performance_names): + patches.append(mpatches.Patch(color=color, label=name)) + ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) else: + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) ax.set_xticks([]) - if i == num_methods - 1 and j == num_methods - 1: - patches = [] - for color, name in zip(colors, performance_names): - patches.append(mpatches.Patch(color=color, label=name)) - ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) + ax.set_yticks([]) plt.tight_layout(h_pad=0, w_pad=0) return fig, axs From 99e7acc8044d91773b2c77c67d51669dfe6b2fd2 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 13 Sep 2023 11:32:37 +0200 Subject: [PATCH 20/22] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 5 +++-- .../sortingcomponents/clustering/random_projections.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 8a7b353bd1..571096caf9 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -145,8 +145,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_params.update({"noise_levels": noise_levels}) matching_job_params = job_kwargs.copy() - if "chunk_memory" in matching_job_params: - matching_job_params.pop("chunk_memory") + for value in ['chunk_size', 'chunk_memory', 'total_memory', 'chunk_duration']: + if value in matching_job_params: + matching_job_params.pop(value) matching_job_params["chunk_duration"] = "100ms" diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index d82f9a7808..025555440a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -191,8 +191,9 @@ def main_function(cls, recording, peaks, params): ) cleaning_matching_params = params["job_kwargs"].copy() - if "chunk_memory" in cleaning_matching_params: - cleaning_matching_params.pop("chunk_memory") + for value in ['chunk_size', 'chunk_memory', 'total_memory', 'chunk_duration']: + if value in cleaning_matching_params: + cleaning_matching_params.pop(value) cleaning_matching_params["chunk_duration"] = "100ms" cleaning_matching_params["n_jobs"] = 1 cleaning_matching_params["verbose"] = False From cc792136cf213c4701a962206295dc7efaa718ad Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Sep 2023 09:32:58 +0000 Subject: [PATCH 21/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/benchmark/benchmark_matching.py | 8 ++++---- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 571096caf9..db3d88f116 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -145,7 +145,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_params.update({"noise_levels": noise_levels}) matching_job_params = job_kwargs.copy() - for value in ['chunk_size', 'chunk_memory', 'total_memory', 'chunk_duration']: + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: if value in matching_job_params: matching_job_params.pop(value) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 8ce8efe25f..50d64e1349 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -626,10 +626,10 @@ def plot_comparison_matching( patches.append(mpatches.Patch(color=color, label=name)) ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) else: - ax.spines['bottom'].set_visible(False) - ax.spines['left'].set_visible(False) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) ax.set_xticks([]) ax.set_yticks([]) plt.tight_layout(h_pad=0, w_pad=0) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 025555440a..5592b23c8d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -191,7 +191,7 @@ def main_function(cls, recording, peaks, params): ) cleaning_matching_params = params["job_kwargs"].copy() - for value in ['chunk_size', 'chunk_memory', 'total_memory', 'chunk_duration']: + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: if value in cleaning_matching_params: cleaning_matching_params.pop(value) cleaning_matching_params["chunk_duration"] = "100ms" From dda78037d9570a529392af35055d343fc6c56022 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 13 Sep 2023 13:26:01 +0200 Subject: [PATCH 22/22] Adding unit_ids --- .../sortingcomponents/clustering/random_projections.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 5592b23c8d..be8ecd6702 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -177,7 +177,8 @@ def main_function(cls, recording, peaks, params): mode = "folder" sorting_folder = tmp_folder / "sorting" - sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) + unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) + sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) sorting = sorting.save(folder=sorting_folder) we = extract_waveforms( recording,