From fb8fc4b477889aa55770b3553a14171678ea0ae4 Mon Sep 17 00:00:00 2001
From: yger
Date: Wed, 12 Apr 2023 17:47:16 +0200
Subject: [PATCH 01/17] WIP

---
 .../sortingcomponents/matching/wobble.py      | 21 ++++++++++---------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py
index 1ba4bddc9a..ea5062f5f6 100644
--- a/src/spikeinterface/sortingcomponents/matching/wobble.py
+++ b/src/spikeinterface/sortingcomponents/matching/wobble.py
@@ -311,11 +311,11 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
             Updated Keyword arguments.
         """
        d = cls.default_params.copy()
-        required_kwargs_keys = ['nbefore', 'nafter', 'templates']
-        for required_key in required_kwargs_keys:
-            assert required_key in kwargs, f"`{required_key}` is a required key in the kwargs"
-        parameters = kwargs.get('parameters', {})
-        templates = kwargs['templates']
+        d.update(kwargs)
+        parameters = d.get('parameters', {})
+        templates = d['waveform_extractor'].get_all_templates()
+        d['nbefore'] = d['waveform_extractor'].nbefore
+        d['nafter'] = d['waveform_extractor'].nafter
         templates = templates.astype(np.float32, casting='safe')

         # Aggregate useful parameters/variables for handy access in downstream functions
@@ -336,11 +336,10 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
                                      norm_squared=norm_squared)

         # Pack initial data into kwargs
-        kwargs['params'] = params
-        kwargs['template_meta'] = template_meta
-        kwargs['sparsity'] = sparsity
-        kwargs['template_data'] = template_data
-        d.update(kwargs)
+        d['params'] = params
+        d['template_meta'] = template_meta
+        d['sparsity'] = sparsity
+        d['template_data'] = template_data

         return d

@@ -348,6 +347,8 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
     def serialize_method_kwargs(cls, kwargs):
         # This function does nothing without a waveform extractor -- candidate for refactor
         kwargs = dict(kwargs)
+        # remove waveform_extractor
+        kwargs.pop('waveform_extractor')
         return kwargs

     @classmethod

From 72679e5cdf01bb7f5344d38435a17cd6b78358d1 Mon Sep 17 00:00:00 2001
From: yger
Date: Wed, 12 Apr 2023 18:10:52 +0200
Subject: [PATCH 02/17] Begin to factorize the sparsity for all matching methods

---
 src/spikeinterface/core/sparsity.py           | 21 ++++++++++++++++++-
 .../sortingcomponents/matching/wobble.py      |  2 +-
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index 64de81b00a..502d87e9dd 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -10,6 +10,8 @@
     * "radius": radius around the best channel. Use the 'radius_um' argument to specify the radius in um
     * "snr": threshold based on template signal-to-noise ratio. Use the 'threshold' argument
       to specify the SNR threshold.
+    * "ptp": threshold based on the peak-to-peak values on every channel. Use the 'threshold' argument
+      to specify the ptp threshold
     * "energy": threshold based on the expected energy that should be present on the channels,
       given their noise levels. Use the 'threshold' argument to specify the SNR threshold
     * "by_property": sparsity is given by a property of the recording and sorting (e.g. 'group').
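The new "ptp" option is exercised through `compute_sparsity`. A minimal usage sketch, assuming `we` is an existing WaveformExtractor (after PATCH 05 below, the threshold is expressed in units of the per-channel noise levels):

    from spikeinterface.core import compute_sparsity

    # keep, for each unit, only the channels whose template peak-to-peak
    # amplitude exceeds the threshold
    sparsity = compute_sparsity(we, method="ptp", threshold=1)
    mask = sparsity.mask  # boolean array of shape (num_units, num_channels)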
@@ -74,7 +76,7 @@ class ChannelSparsity:
         >>> sparsity = ChannelSparsity.from_snr(we, threshold, peak_sign='neg')

     Using a template energy threshold:
-        >>> sparsity = ChannelSparsity.from_energy(we, threshold, peak_sign='neg')
+        >>> sparsity = ChannelSparsity.from_energy(we, threshold)

     Using a recording/sorting property (e.g. 'group'):

@@ -203,6 +205,20 @@ def from_snr(cls, we, threshold, peak_sign='neg'):
             mask[unit_ind, chan_inds] = True
         return cls(mask, we.unit_ids, we.channel_ids)

+    @classmethod
+    def from_ptp(cls, we, threshold):
+        """
+        Construct sparsity from a threshold based on template peak-to-peak values.
+        Use the 'threshold' argument to specify the ptp threshold.
+        """
+
+        mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype='bool')
+        templates_ptps = np.ptp(we.get_all_templates(), axis=1)
+        for unit_ind, unit_id in enumerate(we.unit_ids):
+            chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold)
+            mask[unit_ind, chan_inds] = True
+        return cls(mask, we.unit_ids, we.channel_ids)
+
     @classmethod
     def from_energy(cls, we, threshold):
         """
@@ -283,6 +299,9 @@ def compute_sparsity(
     elif method == "energy":
         assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given"
         sparsity = ChannelSparsity.from_energy(waveform_extractor, threshold)
+    elif method == "ptp":
+        assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given"
+        sparsity = ChannelSparsity.from_ptp(waveform_extractor, threshold)
     elif method == "by_property":
         assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given"
         sparsity = ChannelSparsity.from_property(waveform_extractor, by_property)
diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py
index ea5062f5f6..0cf174e24e 100644
--- a/src/spikeinterface/sortingcomponents/matching/wobble.py
+++ b/src/spikeinterface/sortingcomponents/matching/wobble.py
@@ -23,7 +23,7 @@ class WobbleParameters:
         underlying units of voltage trace and templates.
     approx_rank : int
         Rank of the compressed template matrices.
-    visibility_thresold : float
+    visibility_threshold : float
         Minimum peak amplitude to determine channel sparsity for a given unit. Units depend on the
         underlying units of voltage trace and templates.
     verbose : bool

From 2a0cba008d04f0fd49f4ab019d4279778d9faa13 Mon Sep 17 00:00:00 2001
From: yger
Date: Wed, 12 Apr 2023 22:49:10 +0200
Subject: [PATCH 03/17] Adding the sparsity to all engines.
Dealing with wobble soon --- .../sortingcomponents/matching/circus.py | 115 +++++------------- .../sortingcomponents/matching/main.py | 1 - .../sortingcomponents/matching/naive.py | 43 ++++--- .../sortingcomponents/matching/tdc.py | 13 +- 4 files changed, 61 insertions(+), 111 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 54511d86d6..98513c4974 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -15,7 +15,7 @@ except ImportError: HAVE_SKLEARN = False -from spikeinterface.core import get_noise_levels, get_random_data_chunks +from spikeinterface.core import get_noise_levels, get_random_data_chunks, compute_sparsity from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel potrs, = scipy.linalg.get_lapack_funcs(('potrs',), dtype=np.float32) @@ -153,11 +153,6 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): (Minimal, Maximal) amplitudes allowed for every template omp_min_sps: float Stopping criteria of the OMP algorithm, in percentage of the norm - sparsify_threshold: float - Templates are sparsified in order to keep only the channels necessary - to explain. ptp limit for considering a channel as silent - smoothing_factor: float - Templates are smoothed via Spline Interpolation noise_levels: array The noise levels, for every channels. If None, they will be automatically computed @@ -165,7 +160,6 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): """ _default_params = { - 'sparsify_threshold': 1, 'amplitudes' : [0.6, 2], 'omp_min_sps' : 0.1, 'waveform_extractor': None, @@ -173,32 +167,11 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): 'overlaps' : None, 'norms' : None, 'random_chunk_kwargs': {}, - 'noise_levels': None, - 'smoothing_factor' : 0.25, - 'ignored_ids' : [] + 'noise_levels': None, + 'ignored_ids' : [], + 'sparsity' : {'method' : 'energy', 'threshold' : 1} } - @classmethod - def _sparsify_template(cls, template, sparsify_threshold): - - is_silent = template.ptp(0) < sparsify_threshold - template[:, is_silent] = 0 - active_channels, = np.where(np.logical_not(is_silent)) - - return template, active_channels - - @classmethod - def _regularize_template(cls, template, smoothing_factor=0.25): - - nb_channels = template.shape[1] - nb_timesteps = template.shape[0] - xaxis = np.arange(nb_timesteps) - for i in range(nb_channels): - z = scipy.interpolate.UnivariateSpline(xaxis, template[:, i]) - z.set_smoothing_factor(smoothing_factor) - template[:, i] = z(xaxis) - return template - @classmethod def _prepare_templates(cls, d): @@ -209,20 +182,22 @@ def _prepare_templates(cls, d): templates = waveform_extractor.get_all_templates(mode='median').copy() - d['sparsities'] = {} + if waveform_extractor.is_sparse(): + sparsity = waveform_extractor.sparsity + else: + sparsity = compute_sparsity(waveform_extractor, **d['sparsity']) + d['templates'] = {} + d['sparsities'] = {} d['norms'] = np.zeros(num_templates, dtype=np.float32) - for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): + for unit_ind, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - if d['smoothing_factor'] > 0: - template = cls._regularize_template(templates[count], d['smoothing_factor']) - else: - template = templates[count] - template, active_channels = cls._sparsify_template(template, d['sparsify_threshold']) - d['sparsities'][count] = active_channels - d['norms'][count] = 
np.linalg.norm(template) - d['templates'][count] = template[:, active_channels]/d['norms'][count] + template = templates[unit_ind] + active_channels, = np.nonzero(sparsity.mask[unit_ind]) + d['sparsities'][unit_ind] = active_channels + d['norms'][unit_ind] = np.linalg.norm(template) + d['templates'][unit_ind] = template[:, active_channels]/d['norms'][unit_ind] return d @@ -303,7 +278,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): norms = d['norms'] sparsities = d['sparsities'] - nb_active_channels = np.array([len(sparsities[i]) for i in range(d['num_templates'])]) d['stop_criteria'] = omp_min_sps * np.sqrt(d['noise_levels'].sum() * d['num_samples']) return d @@ -522,9 +496,6 @@ class CircusPeeler(BaseTemplateMatchingEngine): Maximal amplitude allowed for every template min_amplitude: float Minimal amplitude allowed for every template - sparsify_threshold: float - Templates are sparsified in order to keep only the channels necessary - to explain a given fraction of the total norm use_sparse_matrix_threshold: float If density of the templates is below a given threshold, sparse matrix are used (memory efficient) @@ -542,44 +513,14 @@ class CircusPeeler(BaseTemplateMatchingEngine): 'detect_threshold': 5, 'noise_levels': None, 'random_chunk_kwargs': {}, - 'sparsify_threshold': 0.99, 'max_amplitude' : 1.5, 'min_amplitude' : 0.5, 'use_sparse_matrix_threshold' : 0.25, 'progess_bar_steps' : False, 'waveform_extractor': None, - 'smoothing_factor' : 0.25 + 'sparsity' : {'method' : 'energy', 'threshold' : 1} } - @classmethod - def _sparsify_template(cls, template, sparsify_threshold, noise_levels): - - is_silent = template.std(0) < 0.1*noise_levels - - template[:, is_silent] = 0 - - channel_norms = np.linalg.norm(template, axis=0)**2 - total_norm = np.linalg.norm(template)**2 - - idx = np.argsort(channel_norms)[::-1] - explained_norms = np.cumsum(channel_norms[idx]/total_norm) - channel = np.searchsorted(explained_norms, sparsify_threshold) - active_channels = np.sort(idx[:channel]) - template[:, idx[channel:]] = 0 - return template, active_channels - - @classmethod - def _regularize_template(cls, template, smoothing_factor=0.25): - - nb_channels = template.shape[1] - nb_timesteps = template.shape[0] - xaxis = np.arange(nb_timesteps) - for i in range(nb_channels): - z = scipy.interpolate.UnivariateSpline(xaxis, template[:, i]) - z.set_smoothing_factor(smoothing_factor) - template[:, i] = z(xaxis) - return template - @classmethod def _prepare_templates(cls, d): @@ -593,18 +534,18 @@ def _prepare_templates(cls, d): d['norms'] = np.zeros(num_templates, dtype=np.float32) - all_units = list(d['waveform_extractor'].sorting.unit_ids) - templates = waveform_extractor.get_all_templates(mode='median').copy() - - for count, unit_id in enumerate(all_units): - - if d['smoothing_factor'] > 0: - templates[count] = cls._regularize_template(templates[count], d['smoothing_factor']) - templates[count], _ = cls._sparsify_template(templates[count], d['sparsify_threshold'], d['noise_levels']) - d['norms'][count] = np.linalg.norm(templates[count]) - templates[count] /= d['norms'][count] + if waveform_extractor.is_sparse(): + sparsity = waveform_extractor.sparsity + else: + sparsity = compute_sparsity(waveform_extractor, **d['sparsity']) + + for unit_ind, unit_id in enumerate(d['waveform_extractor'].sorting.unit_ids): + active_channels, = np.nonzero(sparsity.mask[unit_ind]) + templates[unit_ind, ~active_channels] = 0 + d['norms'][unit_ind] = np.linalg.norm(templates[unit_ind]) + templates[unit_ind] /= 
d['norms'][unit_ind] templates = templates.reshape(num_templates, -1) @@ -736,7 +677,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): #assert isinstance(d['waveform_extractor'], WaveformExtractor) - for v in ['sparsify_threshold', 'use_sparse_matrix_threshold']: + for v in ['use_sparse_matrix_threshold']: assert (d[v] >= 0) and (d[v] <= 1), f'{v} should be in [0, 1]' d['num_channels'] = d['waveform_extractor'].recording.get_num_channels() diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 0cc4984147..cf61d24541 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -6,7 +6,6 @@ from spikeinterface.core import get_chunk_with_margin - def find_spikes_from_templates(recording, method='naive', method_kwargs={}, extra_outputs=False, **job_kwargs): """Find spike from a recording from given templates. diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 2d89606760..8a535c47f0 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -2,7 +2,7 @@ import numpy as np from spikeinterface.core import WaveformExtractor -from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks +from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks, compute_sparsity from spikeinterface.postprocessing import (get_template_channel_sparsity, get_template_extremum_channel) from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive @@ -29,8 +29,9 @@ class NaiveMatching(BaseTemplateMatchingEngine): 'exclude_sweep_ms': 0.1, 'detect_threshold': 5, 'noise_levels': None, - 'local_radius_um': 100, + 'local_radius_um': 75, 'random_chunk_kwargs': {}, + 'sparsity' : {'method' : 'snr'} } @@ -46,7 +47,21 @@ def initialize_and_check_kwargs(cls, recording, kwargs): if d['noise_levels'] is None: d['noise_levels'] = get_noise_levels(recording, **d['random_chunk_kwargs']) - d['abs_threholds'] = d['noise_levels'] * d['detect_threshold'] + d['abs_thresholds'] = d['noise_levels'] * d['detect_threshold'] + + if we.is_sparse(): + sparsity = we.sparsity + else: + d['sparsity'].update({'peak_sign' : d['peak_sign'], 'threshold' : d['detect_threshold']}) + sparsity = compute_sparsity(we, **d['sparsity']) + + templates = we.get_all_templates() + + for unit_ind, unit_id in enumerate(we.sorting.unit_ids): + active_channels, = np.nonzero(sparsity.mask[unit_ind]) + templates[unit_ind, ~active_channels] = 0 + + d['templates'] = templates channel_distance = get_channel_distances(recording) d['neighbours_mask'] = channel_distance < d['local_radius_um'] @@ -66,31 +81,19 @@ def get_margin(cls, recording, kwargs): @classmethod def serialize_method_kwargs(cls, kwargs): kwargs = dict(kwargs) - - waveform_extractor = kwargs['waveform_extractor'] - kwargs['waveform_extractor'] = str(waveform_extractor.folder) - + # remove waveform_extractor + kwargs.pop('waveform_extractor') return kwargs @classmethod - def unserialize_in_worker(cls, kwargs): - - we = kwargs['waveform_extractor'] - if isinstance(we, str): - we = WaveformExtractor.load(we) - kwargs['waveform_extractor'] = we - - templates = we.get_all_templates(mode='average') - - kwargs['templates'] = templates - + def unserialize_in_worker(cls, 
kwargs): return kwargs @classmethod def main_function(cls, traces, method_kwargs): peak_sign = method_kwargs['peak_sign'] - abs_threholds = method_kwargs['abs_threholds'] + abs_thresholds = method_kwargs['abs_thresholds'] exclude_sweep_size = method_kwargs['exclude_sweep_size'] neighbours_mask = method_kwargs['neighbours_mask'] templates = method_kwargs['templates'] @@ -104,7 +107,7 @@ def main_function(cls, traces, method_kwargs): peak_traces = traces[margin:-margin, :] else: peak_traces = traces - peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks(peak_traces, peak_sign, abs_threholds, exclude_sweep_size, neighbours_mask) + peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks(peak_traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask) peak_sample_ind += margin diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 915264c26e..b7fa9a87d3 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -37,13 +37,15 @@ class TridesclousPeeler(BaseTemplateMatchingEngine): 'peak_shift_ms': 0.2, 'detect_threshold': 5, 'noise_levels': None, - 'local_radius_um': 100, + 'local_radius_um': 75, + 'random_chunk_kwargs': {}, 'num_closest' : 5, 'sample_shift': 3, 'ms_before': 0.8, 'ms_after': 1.2, 'num_peeler_loop': 2, 'num_template_try' : 1, + 'sparsity' : {'method' : 'snr'} } @classmethod @@ -89,14 +91,19 @@ def initialize_and_check_kwargs(cls, recording, kwargs): if d['noise_levels'] is None: print('TridesclousPeeler : noise should be computed outside') - d['noise_levels'] = get_noise_levels(recording) + d['noise_levels'] = get_noise_levels(recording, **d['random_chunk_kwargs']) d['abs_threholds'] = d['noise_levels'] * d['detect_threshold'] channel_distance = get_channel_distances(recording) d['neighbours_mask'] = channel_distance < d['local_radius_um'] - sparsity = compute_sparsity(we, method='snr', peak_sign=d['peak_sign'], threshold=d['detect_threshold']) + if we.is_sparse(): + sparsity = we.sparsity + else: + d['sparsity'].update({'peak_sign' : d['peak_sign'], 'threshold' : d['detect_threshold']}) + sparsity = compute_sparsity(we, **d['sparsity']) + template_sparsity_inds = sparsity.unit_id_to_channel_indices template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype='bool') for unit_index, unit_id in enumerate(unit_ids): From 19cafe6498d0612cd6e1ede9c753f044b1b00127 Mon Sep 17 00:00:00 2001 From: yger Date: Thu, 13 Apr 2023 16:53:04 +0200 Subject: [PATCH 04/17] Refactor --- .../benchmark/benchmark_matching.py | 2 +- .../sortingcomponents/matching/circus.py | 244 +++++------------- .../sortingcomponents/matching/main.py | 60 ++++- .../sortingcomponents/matching/naive.py | 38 +-- .../sortingcomponents/matching/tdc.py | 49 +--- .../sortingcomponents/matching/wobble.py | 4 +- 6 files changed, 133 insertions(+), 264 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 0418251598..4e31fe60e9 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -47,7 +47,7 @@ def __del__(self): def run(self): t_start = time.time() - self.spikes = find_spikes_from_templates(self.recording, method=self.method, method_kwargs=self.method_kwargs, **self.job_kwargs) + self.spikes = 
find_spikes_from_templates(self.recording, self.we, method=self.method, method_kwargs=self.method_kwargs, **self.job_kwargs) self.run_time = time.time() - t_start self.sorting = NumpySorting.from_times_labels(self.spikes['sample_ind'], self.spikes['cluster_ind'], self.sampling_rate) self.comp = CollisionGTComparison(self.gt_sorting, self.sorting, exhaustive_gt=self.exhaustive_gt) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 98513c4974..f91960afb8 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -4,6 +4,7 @@ import warnings import scipy.spatial +from scipy.sparse import csr_matrix from tqdm import tqdm import scipy @@ -132,6 +133,33 @@ def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): return ret +def compute_overlaps(templates, num_samples, num_channels): + + num_templates = len(templates) + + size = 2 * num_samples - 1 + + all_delays = list(range(0, num_samples+1)) + overlaps = {} + + for delay in all_delays: + source = templates[:, :delay, :].reshape(num_templates, -1) + target = templates[:, num_samples-delay:, :].reshape(num_templates, -1) + overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) + + if delay < num_samples: + overlaps[size - delay] = overlaps[delay].T.tocsr() + + new_overlaps = [] + + for i in range(num_templates): + data = [overlaps[j][i, :].T for j in range(size)] + data = scipy.sparse.hstack(data) + new_overlaps += [data] + + return new_overlaps + + class CircusOMPPeeler(BaseTemplateMatchingEngine): """ Orthogonal Matching Pursuit inspired from Spyking Circus sorter @@ -161,82 +189,40 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): _default_params = { 'amplitudes' : [0.6, 2], - 'omp_min_sps' : 0.1, - 'waveform_extractor': None, - 'templates' : None, + 'omp_min_sps' : 0.2, 'overlaps' : None, 'norms' : None, 'random_chunk_kwargs': {}, 'noise_levels': None, - 'ignored_ids' : [], - 'sparsity' : {'method' : 'energy', 'threshold' : 1} + 'ignored_ids' : [] } @classmethod def _prepare_templates(cls, d): - waveform_extractor = d['waveform_extractor'] num_samples = d['num_samples'] num_channels = d['num_channels'] - num_templates = len(d['waveform_extractor'].sorting.unit_ids) - - templates = waveform_extractor.get_all_templates(mode='median').copy() - - if waveform_extractor.is_sparse(): - sparsity = waveform_extractor.sparsity - else: - sparsity = compute_sparsity(waveform_extractor, **d['sparsity']) - - d['templates'] = {} - d['sparsities'] = {} + templates = d['templates'] + num_templates = len(templates) + d['circus_templates'] = d['templates'].copy() d['norms'] = np.zeros(num_templates, dtype=np.float32) - for unit_ind, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - - template = templates[unit_ind] - active_channels, = np.nonzero(sparsity.mask[unit_ind]) - d['sparsities'][unit_ind] = active_channels - d['norms'][unit_ind] = np.linalg.norm(template) - d['templates'][unit_ind] = template[:, active_channels]/d['norms'][unit_ind] + for unit_ind in range(num_templates): + d['norms'][unit_ind] = np.linalg.norm(d['circus_templates'][unit_ind]) + d['circus_templates'][unit_ind] /= d['norms'][unit_ind] return d @classmethod - def _prepare_overlaps(cls, d): - - templates = d['templates'] - num_samples = d['num_samples'] - num_channels = d['num_channels'] - num_templates = d['num_templates'] - sparsities = d['sparsities'] - - dense_templates = 
np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) - for i in range(num_templates): - dense_templates[i, :, sparsities[i]] = templates[i].T - - size = 2 * num_samples - 1 - - all_delays = list(range(0, num_samples+1)) - - overlaps = {} + def _compress_templates(cls, d): - for delay in all_delays: - source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples-delay:, :].reshape(num_templates, -1) - - overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) - - if delay < num_samples: - overlaps[size - delay + 1] = overlaps[delay].T.tocsr() - - new_overlaps = [] - - for i in range(num_templates): - data = [overlaps[j][i, :].T for j in range(size)] - data = scipy.sparse.hstack(data) - new_overlaps += [data] - - d['overlaps'] = new_overlaps + templates = d.pop('circus_templates') + num_templates = len(templates) + d['circus_templates'] = {} + + for unit_ind in range(num_templates): + active_channels = d['sparsity_mask'][unit_ind] + d['circus_templates'][unit_ind] = templates[unit_ind][:, active_channels] return d @@ -246,49 +232,32 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls._default_params.copy() d.update(kwargs) - #assert isinstance(d['waveform_extractor'], WaveformExtractor) - for v in ['omp_min_sps']: assert (d[v] >= 0) and (d[v] <= 1), f'{v} should be in [0, 1]' d['num_channels'] = d['waveform_extractor'].recording.get_num_channels() d['num_samples'] = d['waveform_extractor'].nsamples - d['nbefore'] = d['waveform_extractor'].nbefore - d['nafter'] = d['waveform_extractor'].nafter + d['num_templates'] = len(d['templates']) d['sampling_frequency'] = d['waveform_extractor'].recording.get_sampling_frequency() if d['noise_levels'] is None: print('CircusOMPPeeler : noise should be computed outside') d['noise_levels'] = get_noise_levels(recording, **d['random_chunk_kwargs'], return_scaled=False) - if d['templates'] is None: - d = cls._prepare_templates(d) - else: - for key in ['norms', 'sparsities']: - assert d[key] is not None, "If templates are provided, %d should also be there" %key - - d['num_templates'] = len(d['templates']) + d = cls._prepare_templates(d) if d['overlaps'] is None: - d = cls._prepare_overlaps(d) + d['overlaps'] = compute_overlaps(d['circus_templates'], d['num_samples'], d['num_channels']) + + d = cls._compress_templates(d) d['ignored_ids'] = np.array(d['ignored_ids']) omp_min_sps = d['omp_min_sps'] - norms = d['norms'] - sparsities = d['sparsities'] - d['stop_criteria'] = omp_min_sps * np.sqrt(d['noise_levels'].sum() * d['num_samples']) return d - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - # remove waveform_extractor - kwargs.pop('waveform_extractor') - return kwargs - @classmethod def unserialize_in_worker(cls, kwargs): return kwargs @@ -300,7 +269,7 @@ def get_margin(cls, recording, kwargs): @classmethod def main_function(cls, traces, d): - templates = d['templates'] + templates = d['circus_templates'] num_templates = d['num_templates'] num_channels = d['num_channels'] num_samples = d['num_samples'] @@ -312,7 +281,7 @@ def main_function(cls, traces, d): num_samples = d['nafter'] + d['nbefore'] neighbor_window = num_samples - 1 min_amplitude, max_amplitude = d['amplitudes'] - sparsities = d['sparsities'] + sparsities = d['sparsity_mask'] ignored_ids = d['ignored_ids'] stop_criteria = d['stop_criteria'] @@ -499,8 +468,6 @@ class CircusPeeler(BaseTemplateMatchingEngine): use_sparse_matrix_threshold: float If density of the templates is 
below a given threshold, sparse matrix are used (memory efficient) - progress_bar_steps: bool - In order to display or not steps from the algorithm ----- @@ -516,93 +483,42 @@ class CircusPeeler(BaseTemplateMatchingEngine): 'max_amplitude' : 1.5, 'min_amplitude' : 0.5, 'use_sparse_matrix_threshold' : 0.25, - 'progess_bar_steps' : False, - 'waveform_extractor': None, - 'sparsity' : {'method' : 'energy', 'threshold' : 1} } @classmethod def _prepare_templates(cls, d): - waveform_extractor = d['waveform_extractor'] num_samples = d['num_samples'] num_channels = d['num_channels'] num_templates = d['num_templates'] max_amplitude = d['max_amplitude'] min_amplitude = d['min_amplitude'] use_sparse_matrix_threshold = d['use_sparse_matrix_threshold'] - + d['circus_templates'] = d['templates'].copy() d['norms'] = np.zeros(num_templates, dtype=np.float32) - templates = waveform_extractor.get_all_templates(mode='median').copy() - - if waveform_extractor.is_sparse(): - sparsity = waveform_extractor.sparsity - else: - sparsity = compute_sparsity(waveform_extractor, **d['sparsity']) + for unit_ind in range(num_templates): + d['norms'][unit_ind] = np.linalg.norm(d['circus_templates'][unit_ind]) + d['circus_templates'][unit_ind] /= d['norms'][unit_ind] - for unit_ind, unit_id in enumerate(d['waveform_extractor'].sorting.unit_ids): - active_channels, = np.nonzero(sparsity.mask[unit_ind]) - templates[unit_ind, ~active_channels] = 0 - d['norms'][unit_ind] = np.linalg.norm(templates[unit_ind]) - templates[unit_ind] /= d['norms'][unit_ind] - - templates = templates.reshape(num_templates, -1) - - nnz = np.sum(templates != 0)/(num_templates * num_samples * num_channels) - if nnz <= use_sparse_matrix_threshold: - templates = scipy.sparse.csr_matrix(templates) - print(f'Templates are automatically sparsified (sparsity level is {nnz})') - d['is_dense'] = False - else: - d['is_dense'] = True - - d['templates'] = templates - return d @classmethod - def _prepare_overlaps(cls, d): - - templates = d['templates'] + def _compress_templates(cls, d): + circus_templates = d.pop('circus_templates') + num_templates = len(circus_templates) num_samples = d['num_samples'] num_channels = d['num_channels'] - num_templates = d['num_templates'] - is_dense = d['is_dense'] + circus_templates = circus_templates.reshape(num_templates, -1) - if not is_dense: - dense_templates = templates.toarray() + nnz = np.sum(circus_templates != 0)/(num_templates * num_samples * num_channels) + if nnz <= d['use_sparse_matrix_threshold']: + circus_templates = scipy.sparse.csr_matrix(circus_templates) + d['is_dense'] = False else: - dense_templates = templates - - dense_templates = dense_templates.reshape(num_templates, num_samples, num_channels) - - size = 2 * num_samples - 1 - - all_delays = list(range(0, num_samples+1)) - if d['progess_bar_steps']: - all_delays = tqdm(all_delays, desc='[1] compute overlaps') - - overlaps = {} - - for delay in all_delays: - source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples-delay:, :].reshape(num_templates, -1) - - overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) - - if delay < num_samples: - overlaps[size - delay] = overlaps[delay].T.tocsr() - - new_overlaps = [] - - for i in range(num_templates): - data = [overlaps[j][i, :].T for j in range(size)] - data = scipy.sparse.hstack(data) - new_overlaps += [data] - - d['overlaps'] = new_overlaps + d['is_dense'] = True + d['circus_templates'] = circus_templates return d @classmethod @@ -628,15 +544,13 @@ def 
_cost_function_mcc(cls, bounds, good, bad, delta_amplitude, alpha): def _optimize_amplitudes(cls, noise_snippets, d): waveform_extractor = d['waveform_extractor'] - templates = d['templates'] + templates = d['circus_templates'] num_templates = d['num_templates'] max_amplitude = d['max_amplitude'] min_amplitude = d['min_amplitude'] alpha = 0.5 norms = d['norms'] all_units = list(waveform_extractor.sorting.unit_ids) - if d['progess_bar_steps']: - all_units = tqdm(all_units, desc='[2] compute amplitudes') d['amplitudes'] = np.zeros((num_templates, 2), dtype=np.float32) noise = templates.dot(noise_snippets)/norms[:, np.newaxis] @@ -656,16 +570,6 @@ def _optimize_amplitudes(cls, noise_snippets, d): res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) d['amplitudes'][count] = res.x - # import pylab as plt - # plt.hist(good, 100, alpha=0.5) - # plt.hist(bad, 100, alpha=0.5) - # plt.hist(noise[count], 100, alpha=0.5) - # ymin, ymax = plt.ylim() - # plt.plot([res.x[0], res.x[0]], [ymin, ymax], 'k--') - # plt.plot([res.x[1], res.x[1]], [ymin, ymax], 'k--') - # plt.savefig('test_%d.png' %count) - # plt.close() - return d @classmethod @@ -675,14 +579,12 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls._default_params.copy() d.update(kwargs) - #assert isinstance(d['waveform_extractor'], WaveformExtractor) - for v in ['use_sparse_matrix_threshold']: assert (d[v] >= 0) and (d[v] <= 1), f'{v} should be in [0, 1]' d['num_channels'] = d['waveform_extractor'].recording.get_num_channels() d['num_samples'] = d['waveform_extractor'].nsamples - d['num_templates'] = len(d['waveform_extractor'].sorting.unit_ids) + d['num_templates'] = len(d['templates']) if d['noise_levels'] is None: print('CircusPeeler : noise should be computed outside') @@ -691,7 +593,8 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d['abs_threholds'] = d['noise_levels'] * d['detect_threshold'] d = cls._prepare_templates(d) - d = cls._prepare_overlaps(d) + d['overlaps'] = compute_overlaps(d['circus_templates'], d['num_samples'], d['num_channels']) + d = cls._compress_templates(d) d['exclude_sweep_size'] = int(d['exclude_sweep_ms'] * recording.get_sampling_frequency() / 1000.) 
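The `overlaps` prepared just above encode, for every pair of normalized templates, the channel-summed inner product at every relative shift. A dense sketch of that quantity for a single pair; `pairwise_overlap` is an illustrative helper, not part of the patch:

    import numpy as np

    def pairwise_overlap(template_a, template_b):
        # template_a, template_b: (num_samples, num_channels) arrays;
        # returns the inner products at all 2 * num_samples - 1 relative
        # shifts, summed over channels -- the quantity compute_overlaps()
        # tabulates (as sparse matrices) for every template pair
        num_channels = template_a.shape[1]
        per_channel = [np.correlate(template_a[:, c], template_b[:, c], mode="full")
                       for c in range(num_channels)]
        return np.sum(per_channel, axis=0)

Having these values precomputed lets the peelers account for interactions between temporally overlapping spikes without recomputing any convolution of the traces.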
@@ -714,13 +617,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): return d - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - # remove waveform_extractor - kwargs.pop('waveform_extractor') - return kwargs - @classmethod def unserialize_in_worker(cls, kwargs): return kwargs @@ -735,7 +631,7 @@ def main_function(cls, traces, d): peak_sign = d['peak_sign'] abs_threholds = d['abs_threholds'] exclude_sweep_size = d['exclude_sweep_size'] - templates = d['templates'] + templates = d['circus_templates'] num_templates = d['num_templates'] num_channels = d['num_channels'] overlaps = d['overlaps'] diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index cf61d24541..673d866648 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -3,11 +3,12 @@ import numpy as np from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs -from spikeinterface.core import get_chunk_with_margin +from spikeinterface.core import get_chunk_with_margin, compute_sparsity, WaveformExtractor -def find_spikes_from_templates(recording, method='naive', method_kwargs={}, extra_outputs=False, - **job_kwargs): +def find_spikes_from_templates(recording, waveform_extractor, sparsity={'method' : 'ptp', 'threshold' : 1}, + templates=None, sparsity_mask=None, method='naive', + method_kwargs={}, extra_outputs=False, **job_kwargs): """Find spike from a recording from given templates. Parameters @@ -15,9 +16,16 @@ def find_spikes_from_templates(recording, method='naive', method_kwargs={}, extr recording: RecordingExtractor The recording extractor object waveform_extractor: WaveformExtractor - The waveform extractor + The waveform extractor to get the templates (if templates are not provided manually) + sparsity: dict or None + Parameters that should be given to sparsify the templates, if waveform_extractor + is not already sparse + templates: np.array + If provided, then the templates are used instead of the ones from the waveform_extractor + sparsity_mask: np.array, bool + If provided, the sparsity mask used for the provided templates method: str - Which method to use ('naive' | 'tridesclous' | 'circus') + Which method to use ('naive' | 'tridesclous' | 'circus' | 'circus-omp' | 'wobble') method_kwargs: dict, optional Keyword arguments for the chosen method extra_outputs: bool @@ -43,6 +51,10 @@ def find_spikes_from_templates(recording, method='naive', method_kwargs={}, extr method_class = matching_methods[method] + # initialize the templates + method_kwargs = method_class.initialize_and_sparsify_templates(method_kwargs, waveform_extractor, sparsity, + templates, sparsity_mask) + # initialize method_kwargs = method_class.initialize_and_check_kwargs(recording, method_kwargs) @@ -118,8 +130,38 @@ def _find_spikes_chunk(segment_index, start_frame, end_frame, worker_ctx): # generic class for template engine class BaseTemplateMatchingEngine: - default_params = {} + @classmethod + def initialize_and_sparsify_templates(cls, kwargs, waveform_extractor, sparsity, templates, sparsity_mask): + assert isinstance(waveform_extractor, WaveformExtractor) + kwargs.update({'nbefore' : waveform_extractor.nbefore, + 'nafter' : waveform_extractor.nafter, + 'sampling_frequency' : waveform_extractor.sampling_frequency}) + + num_channels = waveform_extractor.get_num_channels() + + if templates is not None: + kwargs['templates'] = templates.copy() + 
num_templates = len(templates) + if sparsity_mask is None: + kwargs['sparsity_mask'] = np.ones((num_templates, num_channels), dtype=bool) + else: + kwargs['templates'] = waveform_extractor.get_all_templates().copy() + num_templates = len(kwargs['templates']) + if waveform_extractor.is_sparse(): + kwargs['sparsity_mask'] = waveform_extractor.sparsity.mask + else: + if sparsity is not None: + kwargs['sparsity_mask'] = compute_sparsity(waveform_extractor, **sparsity).mask + else: + kwargs['sparsity_mask'] = np.ones((num_templates, num_channels), dtype=bool) + + for unit_ind in range(num_templates): + active_channels = kwargs['sparsity_mask'][unit_ind] + kwargs['templates'][unit_ind][:, ~active_channels] = 0 + + return kwargs + @classmethod def initialize_and_check_kwargs(cls, recording, kwargs): """This function runs before loops""" @@ -129,8 +171,10 @@ def initialize_and_check_kwargs(cls, recording, kwargs): @classmethod def serialize_method_kwargs(cls, kwargs): """This function serializes kwargs to distribute them to workers""" - # need to be implemented in subclass - raise NotImplementedError + kwargs = dict(kwargs) + # remove waveform_extractor + kwargs.pop('waveform_extractor') + return kwargs @classmethod def unserialize_in_worker(cls, recording, kwargs): diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 8a535c47f0..ab0809606f 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -24,14 +24,12 @@ class NaiveMatching(BaseTemplateMatchingEngine): And also as an example how to deal with methods_kwargs, margin, intit, func, ... """ default_params = { - 'waveform_extractor': None, 'peak_sign': 'neg', 'exclude_sweep_ms': 0.1, 'detect_threshold': 5, 'noise_levels': None, 'local_radius_um': 75, - 'random_chunk_kwargs': {}, - 'sparsity' : {'method' : 'snr'} + 'random_chunk_kwargs': {} } @@ -40,36 +38,14 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls.default_params.copy() d.update(kwargs) - assert d['waveform_extractor'] is not None - - we = d['waveform_extractor'] - if d['noise_levels'] is None: d['noise_levels'] = get_noise_levels(recording, **d['random_chunk_kwargs']) d['abs_thresholds'] = d['noise_levels'] * d['detect_threshold'] - if we.is_sparse(): - sparsity = we.sparsity - else: - d['sparsity'].update({'peak_sign' : d['peak_sign'], 'threshold' : d['detect_threshold']}) - sparsity = compute_sparsity(we, **d['sparsity']) - - templates = we.get_all_templates() - - for unit_ind, unit_id in enumerate(we.sorting.unit_ids): - active_channels, = np.nonzero(sparsity.mask[unit_ind]) - templates[unit_ind, ~active_channels] = 0 - - d['templates'] = templates - channel_distance = get_channel_distances(recording) d['neighbours_mask'] = channel_distance < d['local_radius_um'] - - d['nbefore'] = we.nbefore - d['nafter'] = we.nafter - - d['exclude_sweep_size'] = int(d['exclude_sweep_ms'] * recording.get_sampling_frequency() / 1000.) + d['exclude_sweep_size'] = int(d['exclude_sweep_ms'] * d['sampling_frequency'] / 1000.) 
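        # exclude_sweep_ms is converted to samples using the cached
        # d['sampling_frequency'] (populated by initialize_and_sparsify_templates),
        # so the engine no longer needs to query the recording here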
return d @@ -78,13 +54,6 @@ def get_margin(cls, recording, kwargs): margin = max(kwargs['nbefore'], kwargs['nafter']) return margin - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - # remove waveform_extractor - kwargs.pop('waveform_extractor') - return kwargs - @classmethod def unserialize_in_worker(cls, kwargs): return kwargs @@ -97,10 +66,8 @@ def main_function(cls, traces, method_kwargs): exclude_sweep_size = method_kwargs['exclude_sweep_size'] neighbours_mask = method_kwargs['neighbours_mask'] templates = method_kwargs['templates'] - nbefore = method_kwargs['nbefore'] nafter = method_kwargs['nafter'] - margin = method_kwargs['margin'] if margin > 0: @@ -110,7 +77,6 @@ def main_function(cls, traces, method_kwargs): peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks(peak_traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask) peak_sample_ind += margin - spikes = np.zeros(peak_sample_ind.size, dtype=spike_dtype) spikes['sample_ind'] = peak_sample_ind spikes['channel_ind'] = peak_chan_ind # TODO need to put the channel from template diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index b7fa9a87d3..76300360b9 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -1,7 +1,7 @@ import numpy as np import scipy from spikeinterface.core import (WaveformExtractor, get_noise_levels, get_channel_distances, - compute_sparsity, get_template_extremum_channel) + get_template_extremum_channel) from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive @@ -32,7 +32,6 @@ class TridesclousPeeler(BaseTemplateMatchingEngine): spike collision when templates have high similarity. """ default_params = { - 'waveform_extractor': None, 'peak_sign': 'neg', 'peak_shift_ms': 0.2, 'detect_threshold': 5, @@ -44,8 +43,7 @@ class TridesclousPeeler(BaseTemplateMatchingEngine): 'ms_before': 0.8, 'ms_after': 1.2, 'num_peeler_loop': 2, - 'num_template_try' : 1, - 'sparsity' : {'method' : 'snr'} + 'num_template_try' : 1 } @classmethod @@ -56,23 +54,13 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls.default_params.copy() d.update(kwargs) - assert isinstance(d['waveform_extractor'], WaveformExtractor) - we = d['waveform_extractor'] + templates = d['templates'] unit_ids = we.unit_ids channel_ids = we.channel_ids sr = we.sampling_frequency - - # TODO load as sharedmem - templates = we.get_all_templates(mode='average') - d['templates'] = templates - - d['nbefore'] = we.nbefore - d['nafter'] = we.nafter - - nbefore_short = int(d['ms_before'] * sr / 1000.) nafter_short = int(d['ms_before'] * sr / 1000.) 
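        # note: both short windows above are derived from d['ms_before'];
        # d['ms_after'] is presumably the intended source for nafter_short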
assert nbefore_short <= we.nbefore @@ -86,7 +74,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): templates_short = templates[:, slice(s0,s1), :].copy() d['templates_short'] = templates_short - d['peak_shift'] = int(d['peak_shift_ms'] / 1000 * sr) if d['noise_levels'] is None: @@ -97,20 +84,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): channel_distance = get_channel_distances(recording) d['neighbours_mask'] = channel_distance < d['local_radius_um'] - - if we.is_sparse(): - sparsity = we.sparsity - else: - d['sparsity'].update({'peak_sign' : d['peak_sign'], 'threshold' : d['detect_threshold']}) - sparsity = compute_sparsity(we, **d['sparsity']) - - template_sparsity_inds = sparsity.unit_id_to_channel_indices - template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype='bool') - for unit_index, unit_id in enumerate(unit_ids): - chan_inds = template_sparsity_inds[unit_id] - template_sparsity[unit_index, chan_inds] = True - - d['template_sparsity'] = template_sparsity extremum_channel = get_template_extremum_channel(we, peak_sign=d['peak_sign'], outputs='index') # as numpy vector @@ -135,7 +108,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): closest_u = np.array(closest_u[:d['num_closest']]) # compute unitary discriminent vector - chans, = np.nonzero(d['template_sparsity'][unit_ind, :]) + chans, = np.nonzero(d['sparsity_mask'][unit_ind, :]) template_sparse = templates[unit_ind, :, :][:, chans] closest_vec = [] # against N closets @@ -165,14 +138,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): return d - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - - # remove waveform_extractor - kwargs.pop('waveform_extractor') - return kwargs - @classmethod def unserialize_in_worker(cls, kwargs): return kwargs @@ -257,11 +222,11 @@ def _tdc_find_spikes(traces, d, level=0): # distances = np.sum(np.sum((templates[possible_clusters, :, :] - wf[None, : , :])**2, axis=1), axis=1) ## pure numpy with cluster+channel spasity - # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, :], axis=0)) + # union_channels, = np.nonzero(np.any(d['sparsity_mask'][possible_clusters, :], axis=0)) # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) ## numba with cluster+channel spasity - union_channels = np.any(d['template_sparsity'][possible_clusters, :], axis=0) + union_channels = np.any(d['sparsity_mask'][possible_clusters, :], axis=0) # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) distances = numba_sparse_dist(wf_short, templates_short, union_channels, possible_clusters) @@ -273,7 +238,7 @@ def _tdc_find_spikes(traces, d, level=0): for ind in np.argsort(distances)[:d['num_template_try']]: cluster_ind = possible_clusters[ind] - chan_sparsity = d['template_sparsity'][cluster_ind, :] + chan_sparsity = d['sparsity_mask'][cluster_ind, :] template_sparse = templates[cluster_ind, :, :][:, chan_sparsity] # find best shift diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 0cf174e24e..cd95829024 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -313,9 +313,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls.default_params.copy() d.update(kwargs) parameters = d.get('parameters', {}) - templates = 
d['waveform_extractor'].get_all_templates()
-        d['nbefore'] = d['waveform_extractor'].nbefore
-        d['nafter'] = d['waveform_extractor'].nafter
+        templates = d['templates']
         templates = templates.astype(np.float32, casting='safe')

         # Aggregate useful parameters/variables for handy access in downstream functions

From 5351a1f10888011f82c60b89bc806943b2823508 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Tue, 25 Apr 2023 15:46:29 +0200
Subject: [PATCH 05/17] Patch

---
 src/spikeinterface/core/sparsity.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index 502d87e9dd..5a116cb0f4 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -9,11 +9,12 @@
       number of channels.
     * "radius": radius around the best channel. Use the 'radius_um' argument to specify the radius in um
     * "snr": threshold based on template signal-to-noise ratio. Use the 'threshold' argument
-      to specify the SNR threshold.
+      to specify the SNR threshold (in units of noise levels)
     * "ptp": threshold based on the peak-to-peak values on every channel. Use the 'threshold' argument
-      to specify the ptp threshold
+      to specify the ptp threshold (in units of noise levels)
     * "energy": threshold based on the expected energy that should be present on the channels,
-      given their noise levels. Use the 'threshold' argument to specify the SNR threshold
+      given their noise levels. Use the 'threshold' argument to specify the SNR threshold
+      (in units of noise levels)
     * "by_property": sparsity is given by a property of the recording and sorting (e.g. 'group').
       Use the 'by_property' argument to specify the property name.

@@ -214,8 +215,9 @@ def from_ptp(cls, we, threshold):

         mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype='bool')
         templates_ptps = np.ptp(we.get_all_templates(), axis=1)
+        noise = get_noise_levels(we.recording, return_scaled=we.return_scaled)
         for unit_ind, unit_id in enumerate(we.unit_ids):
-            chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold)
+            chan_inds = np.nonzero(templates_ptps[unit_ind] / noise >= threshold)
             mask[unit_ind, chan_inds] = True
         return cls(mask, we.unit_ids, we.channel_ids)

From 7deca6aa94855683f6d699b869d0f4c336ad8dca Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 16 May 2023 11:44:28 +0000
Subject: [PATCH 06/17] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/__init__.py                |   1 +
 src/spikeinterface/core/sparsity.py           |   6 +-
 .../benchmark/benchmark_matching.py           |   4 +-
 .../sortingcomponents/matching/circus.py      | 197 +++++++++---------
 .../sortingcomponents/matching/main.py        |  54 +++--
 .../sortingcomponents/matching/naive.py       |  59 +++---
 .../sortingcomponents/matching/tdc.py         |  10 +-
 .../sortingcomponents/matching/wobble.py      |  16 +-
 8 files changed, 182 insertions(+), 165 deletions(-)

diff --git a/src/spikeinterface/__init__.py b/src/spikeinterface/__init__.py
index 5f2086ede7..99325d7d65 100644
--- a/src/spikeinterface/__init__.py
+++ b/src/spikeinterface/__init__.py
@@ -9,6 +9,7 @@
 from .core import *

 import warnings
+
 warnings.filterwarnings("ignore", message="distutils Version classes are deprecated")
 warnings.filterwarnings("ignore", message="the imp module is deprecated")

diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index f6c7604d16..4c3680b021 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -12,8 +12,8 @@
       to specify the SNR threshold (in units of noise levels)
     * "ptp": threshold based on the peak-to-peak values on every channel. Use the 'threshold' argument
       to specify the ptp threshold (in units of noise levels)
-    * "energy": threshold based on the expected energy that should be present on the channels,
-      given their noise levels. Use the 'threshold' argument to specify the SNR threshold
+    * "energy": threshold based on the expected energy that should be present on the channels,
+      given their noise levels. Use the 'threshold' argument to specify the SNR threshold
       (in units of noise levels)
     * "by_property": sparsity is given by a property of the recording and sorting (e.g. 'group').
       Use the 'by_property' argument to specify the property name.

@@ -214,7 +214,7 @@ def from_ptp(cls, we, threshold):
         Use the 'threshold' argument to specify the ptp threshold.
         """

-        mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype='bool')
+        mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool")
         templates_ptps = np.ptp(we.get_all_templates(), axis=1)
         noise = get_noise_levels(we.recording, return_scaled=we.return_scaled)
         for unit_ind, unit_id in enumerate(we.unit_ids):
diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py
index 450bb6adb1..e943b1281e 100644
--- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py
+++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py
@@ -62,7 +62,9 @@ def __del__(self):

     def run(self):
         t_start = time.time()
-        self.spikes = find_spikes_from_templates(self.recording, self.we, method=self.method, method_kwargs=self.method_kwargs, **self.job_kwargs)
+        self.spikes = find_spikes_from_templates(
+            self.recording, self.we, method=self.method, method_kwargs=self.method_kwargs, **self.job_kwargs
+        )
         self.run_time = time.time() - t_start
         self.sorting = NumpySorting.from_times_labels(
             self.spikes["sample_index"], self.spikes["cluster_index"], self.sampling_rate
diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py
index 56a9063b5e..c4e1716640 100644
--- a/src/spikeinterface/sortingcomponents/matching/circus.py
+++ b/src/spikeinterface/sortingcomponents/matching/circus.py
@@ -133,17 +133,16 @@ def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True):
     return ret


 def compute_overlaps(templates, num_samples, num_channels):
-
     num_templates = len(templates)

     size = 2 * num_samples - 1

-    all_delays = list(range(0, num_samples+1))
+    all_delays = list(range(0, num_samples + 1))
     overlaps = {}
-
+
     for delay in all_delays:
         source = templates[:, :delay, :].reshape(num_templates, -1)
-        target = templates[:, num_samples-delay:, :].reshape(num_templates, -1)
+        target = templates[:, num_samples - delay :, :].reshape(num_templates, -1)
         overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T))

         if delay < num_samples:
@@ -189,41 +188,39 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine):
     """

     _default_params = {
-        'amplitudes' : [0.6, 2],
-        'omp_min_sps' : 0.2,
-        'overlaps' : None,
-        'norms' : None,
-        'random_chunk_kwargs': {},
-        'noise_levels': None,
-        'ignored_ids' : []
+        "amplitudes": [0.6, 2],
+        "omp_min_sps": 0.2,
+        "overlaps": None,
+        "norms": None,
+        "random_chunk_kwargs": {},
+        "noise_levels": None,
+        "ignored_ids": [],
     }

     @classmethod
     def _prepare_templates(cls, d):
- - num_samples = d['num_samples'] - num_channels = d['num_channels'] - templates = d['templates'] + num_samples = d["num_samples"] + num_channels = d["num_channels"] + templates = d["templates"] num_templates = len(templates) - d['circus_templates'] = d['templates'].copy() - d['norms'] = np.zeros(num_templates, dtype=np.float32) + d["circus_templates"] = d["templates"].copy() + d["norms"] = np.zeros(num_templates, dtype=np.float32) for unit_ind in range(num_templates): - d['norms'][unit_ind] = np.linalg.norm(d['circus_templates'][unit_ind]) - d['circus_templates'][unit_ind] /= d['norms'][unit_ind] + d["norms"][unit_ind] = np.linalg.norm(d["circus_templates"][unit_ind]) + d["circus_templates"][unit_ind] /= d["norms"][unit_ind] return d @classmethod def _compress_templates(cls, d): - - templates = d.pop('circus_templates') + templates = d.pop("circus_templates") num_templates = len(templates) - d['circus_templates'] = {} - + d["circus_templates"] = {} + for unit_ind in range(num_templates): - active_channels = d['sparsity_mask'][unit_ind] - d['circus_templates'][unit_ind] = templates[unit_ind][:, active_channels] + active_channels = d["sparsity_mask"][unit_ind] + d["circus_templates"][unit_ind] = templates[unit_ind][:, active_channels] return d @@ -232,13 +229,13 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls._default_params.copy() d.update(kwargs) - for v in ['omp_min_sps']: - assert (d[v] >= 0) and (d[v] <= 1), f'{v} should be in [0, 1]' - - d['num_channels'] = d['waveform_extractor'].recording.get_num_channels() - d['num_samples'] = d['waveform_extractor'].nsamples - d['num_templates'] = len(d['templates']) - d['sampling_frequency'] = d['waveform_extractor'].recording.get_sampling_frequency() + for v in ["omp_min_sps"]: + assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" + + d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() + d["num_samples"] = d["waveform_extractor"].nsamples + d["num_templates"] = len(d["templates"]) + d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() d["num_samples"] = d["waveform_extractor"].nsamples @@ -248,15 +245,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls._prepare_templates(d) - if d['overlaps'] is None: - d['overlaps'] = compute_overlaps(d['circus_templates'], d['num_samples'], d['num_channels']) + if d["overlaps"] is None: + d["overlaps"] = compute_overlaps(d["circus_templates"], d["num_samples"], d["num_channels"]) d = cls._compress_templates(d) d["ignored_ids"] = np.array(d["ignored_ids"]) - omp_min_sps = d['omp_min_sps'] - d['stop_criteria'] = omp_min_sps * np.sqrt(d['noise_levels'].sum() * d['num_samples']) + omp_min_sps = d["omp_min_sps"] + d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) return d @@ -271,21 +268,21 @@ def get_margin(cls, recording, kwargs): @classmethod def main_function(cls, traces, d): - templates = d['circus_templates'] - num_templates = d['num_templates'] - num_channels = d['num_channels'] - num_samples = d['num_samples'] - overlaps = d['overlaps'] - norms = d['norms'] - nbefore = d['nbefore'] - nafter = d['nafter'] + templates = d["circus_templates"] + num_templates = d["num_templates"] + num_channels = d["num_channels"] + num_samples = d["num_samples"] + overlaps = d["overlaps"] + norms = d["norms"] + nbefore = d["nbefore"] + nafter = d["nafter"] omp_tol = np.finfo(np.float32).eps num_samples = d["nafter"] 
+ d["nbefore"] neighbor_window = num_samples - 1 - min_amplitude, max_amplitude = d['amplitudes'] - sparsities = d['sparsity_mask'] - ignored_ids = d['ignored_ids'] - stop_criteria = d['stop_criteria'] + min_amplitude, max_amplitude = d["amplitudes"] + sparsities = d["sparsity_mask"] + ignored_ids = d["ignored_ids"] + stop_criteria = d["stop_criteria"] if "cached_fft_kernels" not in d: d["cached_fft_kernels"] = {"fshape": 0} @@ -474,51 +471,50 @@ class CircusPeeler(BaseTemplateMatchingEngine): """ _default_params = { - 'peak_sign': 'neg', - 'exclude_sweep_ms': 0.1, - 'jitter_ms' : 0.1, - 'detect_threshold': 5, - 'noise_levels': None, - 'random_chunk_kwargs': {}, - 'max_amplitude' : 1.5, - 'min_amplitude' : 0.5, - 'use_sparse_matrix_threshold' : 0.25, + "peak_sign": "neg", + "exclude_sweep_ms": 0.1, + "jitter_ms": 0.1, + "detect_threshold": 5, + "noise_levels": None, + "random_chunk_kwargs": {}, + "max_amplitude": 1.5, + "min_amplitude": 0.5, + "use_sparse_matrix_threshold": 0.25, } @classmethod def _prepare_templates(cls, d): - - num_samples = d['num_samples'] - num_channels = d['num_channels'] - num_templates = d['num_templates'] - max_amplitude = d['max_amplitude'] - min_amplitude = d['min_amplitude'] - use_sparse_matrix_threshold = d['use_sparse_matrix_threshold'] - d['circus_templates'] = d['templates'].copy() - d['norms'] = np.zeros(num_templates, dtype=np.float32) + num_samples = d["num_samples"] + num_channels = d["num_channels"] + num_templates = d["num_templates"] + max_amplitude = d["max_amplitude"] + min_amplitude = d["min_amplitude"] + use_sparse_matrix_threshold = d["use_sparse_matrix_threshold"] + d["circus_templates"] = d["templates"].copy() + d["norms"] = np.zeros(num_templates, dtype=np.float32) for unit_ind in range(num_templates): - d['norms'][unit_ind] = np.linalg.norm(d['circus_templates'][unit_ind]) - d['circus_templates'][unit_ind] /= d['norms'][unit_ind] - + d["norms"][unit_ind] = np.linalg.norm(d["circus_templates"][unit_ind]) + d["circus_templates"][unit_ind] /= d["norms"][unit_ind] + return d @classmethod def _compress_templates(cls, d): - circus_templates = d.pop('circus_templates') + circus_templates = d.pop("circus_templates") num_templates = len(circus_templates) - num_samples = d['num_samples'] - num_channels = d['num_channels'] + num_samples = d["num_samples"] + num_channels = d["num_channels"] circus_templates = circus_templates.reshape(num_templates, -1) - nnz = np.sum(circus_templates != 0)/(num_templates * num_samples * num_channels) - if nnz <= d['use_sparse_matrix_threshold']: + nnz = np.sum(circus_templates != 0) / (num_templates * num_samples * num_channels) + if nnz <= d["use_sparse_matrix_threshold"]: circus_templates = scipy.sparse.csr_matrix(circus_templates) - d['is_dense'] = False + d["is_dense"] = False else: parameters["is_dense"] = True - d['circus_templates'] = circus_templates + d["circus_templates"] = circus_templates return d @classmethod @@ -544,12 +540,11 @@ def _cost_function_mcc(cls, bounds, good, bad, delta_amplitude, alpha): @classmethod def _optimize_amplitudes(cls, noise_snippets, d): - - waveform_extractor = d['waveform_extractor'] - templates = d['circus_templates'] - num_templates = d['num_templates'] - max_amplitude = d['max_amplitude'] - min_amplitude = d['min_amplitude'] + waveform_extractor = d["waveform_extractor"] + templates = d["circus_templates"] + num_templates = d["num_templates"] + max_amplitude = d["max_amplitude"] + min_amplitude = d["min_amplitude"] alpha = 0.5 norms = parameters["norms"] all_units = 
list(waveform_extractor.sorting.unit_ids) @@ -580,12 +575,12 @@ def initialize_and_check_kwargs(cls, recording, kwargs): default_parameters = cls._default_params.copy() default_parameters.update(kwargs) - for v in ['use_sparse_matrix_threshold']: - assert (d[v] >= 0) and (d[v] <= 1), f'{v} should be in [0, 1]' - - d['num_channels'] = d['waveform_extractor'].recording.get_num_channels() - d['num_samples'] = d['waveform_extractor'].nsamples - d['num_templates'] = len(d['templates']) + for v in ["use_sparse_matrix_threshold"]: + assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" + + d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() + d["num_samples"] = d["waveform_extractor"].nsamples + d["num_templates"] = len(d["templates"]) default_parameters["num_channels"] = default_parameters["waveform_extractor"].recording.get_num_channels() default_parameters["num_samples"] = default_parameters["waveform_extractor"].nsamples @@ -598,7 +593,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): ) d = cls._prepare_templates(d) - d['overlaps'] = compute_overlaps(d['circus_templates'], d['num_samples'], d['num_channels']) + d["overlaps"] = compute_overlaps(d["circus_templates"], d["num_samples"], d["num_channels"]) d = cls._compress_templates(d) default_parameters = cls._prepare_templates(default_parameters) @@ -649,18 +644,18 @@ def get_margin(cls, recording, kwargs): @classmethod def main_function(cls, traces, d): - peak_sign = d['peak_sign'] - abs_threholds = d['abs_threholds'] - exclude_sweep_size = d['exclude_sweep_size'] - templates = d['circus_templates'] - num_templates = d['num_templates'] - num_channels = d['num_channels'] - overlaps = d['overlaps'] - margin = d['margin'] - norms = d['norms'] - jitter = d['jitter'] - patch_sizes = d['patch_sizes'] - num_samples = d['nafter'] + d['nbefore'] + peak_sign = d["peak_sign"] + abs_threholds = d["abs_threholds"] + exclude_sweep_size = d["exclude_sweep_size"] + templates = d["circus_templates"] + num_templates = d["num_templates"] + num_channels = d["num_channels"] + overlaps = d["overlaps"] + margin = d["margin"] + norms = d["norms"] + jitter = d["jitter"] + patch_sizes = d["patch_sizes"] + num_samples = d["nafter"] + d["nbefore"] neighbor_window = num_samples - 1 amplitudes = d["amplitudes"] sym_patch = d["sym_patch"] diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index a9402ce71b..ab92316bfd 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -5,9 +5,17 @@ from spikeinterface.core import get_chunk_with_margin, compute_sparsity, WaveformExtractor -def find_spikes_from_templates(recording, waveform_extractor, sparsity={'method' : 'ptp', 'threshold' : 1}, - templates=None, sparsity_mask=None, method='naive', - method_kwargs={}, extra_outputs=False, **job_kwargs): +def find_spikes_from_templates( + recording, + waveform_extractor, + sparsity={"method": "ptp", "threshold": 1}, + templates=None, + sparsity_mask=None, + method="naive", + method_kwargs={}, + extra_outputs=False, + **job_kwargs, +): """Find spike from a recording from given templates. 
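 
     A minimal call, as an orientation sketch (``recording`` and ``we`` are
     assumed names for an existing recording and waveform extractor; the
     fields shown below are the ones this function fills in):
 
     >>> spikes = find_spikes_from_templates(recording, we, method="naive")
     >>> spikes["sample_index"], spikes["cluster_index"]
 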
Parameters @@ -23,7 +31,7 @@ def find_spikes_from_templates(recording, waveform_extractor, sparsity={'method' If provided, then the templates are used instead of the ones from the waveform_extractor sparsity_mask: np.array, bool If provided, the sparsity mask used for the provided templates - method: str + method: str Which method to use ('naive' | 'tridesclous' | 'circus' | 'circus-omp' | 'wobble') method_kwargs: dict, optional Keyword arguments for the chosen method @@ -50,10 +58,11 @@ def find_spikes_from_templates(recording, waveform_extractor, sparsity={'method' job_kwargs = fix_job_kwargs(job_kwargs) method_class = matching_methods[method] - + # initialize the templates - method_kwargs = method_class.initialize_and_sparsify_templates(method_kwargs, waveform_extractor, sparsity, - templates, sparsity_mask) + method_kwargs = method_class.initialize_and_sparsify_templates( + method_kwargs, waveform_extractor, sparsity, templates, sparsity_mask + ) # initialize method_kwargs = method_class.initialize_and_check_kwargs(recording, method_kwargs) @@ -131,35 +140,38 @@ def _find_spikes_chunk(segment_index, start_frame, end_frame, worker_ctx): # generic class for template engine class BaseTemplateMatchingEngine: - @classmethod def initialize_and_sparsify_templates(cls, kwargs, waveform_extractor, sparsity, templates, sparsity_mask): assert isinstance(waveform_extractor, WaveformExtractor) - kwargs.update({'nbefore' : waveform_extractor.nbefore, - 'nafter' : waveform_extractor.nafter, - 'sampling_frequency' : waveform_extractor.sampling_frequency}) + kwargs.update( + { + "nbefore": waveform_extractor.nbefore, + "nafter": waveform_extractor.nafter, + "sampling_frequency": waveform_extractor.sampling_frequency, + } + ) num_channels = waveform_extractor.get_num_channels() if templates is not None: - kwargs['templates'] = templates.copy() + kwargs["templates"] = templates.copy() num_templates = len(templates) if sparsity_mask is None: - kwargs['sparsity_mask'] = np.ones((num_templates, num_channels), dtype=bool) + kwargs["sparsity_mask"] = np.ones((num_templates, num_channels), dtype=bool) else: - kwargs['templates'] = waveform_extractor.get_all_templates().copy() - num_templates = len(kwargs['templates']) + kwargs["templates"] = waveform_extractor.get_all_templates().copy() + num_templates = len(kwargs["templates"]) if waveform_extractor.is_sparse(): - kwargs['sparsity_mask'] = waveform_extractor.sparsity.mask + kwargs["sparsity_mask"] = waveform_extractor.sparsity.mask else: if sparsity is not None: - kwargs['sparsity_mask'] = compute_sparsity(waveform_extractor, **sparsity).mask + kwargs["sparsity_mask"] = compute_sparsity(waveform_extractor, **sparsity).mask else: - kwargs['sparsity_mask'] = np.ones((num_templates, num_channels), dtype=bool) + kwargs["sparsity_mask"] = np.ones((num_templates, num_channels), dtype=bool) for unit_ind in range(num_templates): - active_channels = kwargs['sparsity_mask'][unit_ind] - kwargs['templates'][unit_ind][:, ~active_channels] = 0 + active_channels = kwargs["sparsity_mask"][unit_ind] + kwargs["templates"][unit_ind][:, ~active_channels] = 0 return kwargs @@ -174,7 +186,7 @@ def serialize_method_kwargs(cls, kwargs): """This function serializes kwargs to distribute them to workers""" kwargs = dict(kwargs) # remove waveform_extractor - kwargs.pop('waveform_extractor') + kwargs.pop("waveform_extractor") return kwargs @classmethod diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 
0bb555e919..0073858200 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -2,8 +2,14 @@ import numpy as np from spikeinterface.core import WaveformExtractor -from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks, compute_sparsity -from spikeinterface.postprocessing import (get_template_channel_sparsity, get_template_extremum_channel) +from spikeinterface.core import ( + get_noise_levels, + get_channel_distances, + get_chunk_with_margin, + get_random_data_chunks, + compute_sparsity, +) +from spikeinterface.postprocessing import get_template_channel_sparsity, get_template_extremum_channel from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive @@ -30,27 +36,27 @@ class NaiveMatching(BaseTemplateMatchingEngine): """ default_params = { - 'peak_sign': 'neg', - 'exclude_sweep_ms': 0.1, - 'detect_threshold': 5, - 'noise_levels': None, - 'local_radius_um': 75, - 'random_chunk_kwargs': {} + "peak_sign": "neg", + "exclude_sweep_ms": 0.1, + "detect_threshold": 5, + "noise_levels": None, + "local_radius_um": 75, + "random_chunk_kwargs": {}, } @classmethod def initialize_and_check_kwargs(cls, recording, kwargs): d = cls.default_params.copy() d.update(kwargs) - - if d['noise_levels'] is None: - d['noise_levels'] = get_noise_levels(recording, **d['random_chunk_kwargs']) - d['abs_thresholds'] = d['noise_levels'] * d['detect_threshold'] + if d["noise_levels"] is None: + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"]) + + d["abs_thresholds"] = d["noise_levels"] * d["detect_threshold"] channel_distance = get_channel_distances(recording) - d['neighbours_mask'] = channel_distance < d['local_radius_um'] - d['exclude_sweep_size'] = int(d['exclude_sweep_ms'] * d['sampling_frequency'] / 1000.) 
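         # Worked numbers for the two conversions here (a 30 kHz sampling
         # rate is assumed purely for illustration): the default
         # exclude_sweep_ms=0.1 gives int(0.1 * 30000 / 1000.0) == 3 samples,
         # and local_radius_um=75 marks channel pairs closer than 75 um as
         # neighbours in neighbours_mask.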
+ d["neighbours_mask"] = channel_distance < d["local_radius_um"] + d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * d["sampling_frequency"] / 1000.0) return d @@ -60,26 +66,27 @@ def get_margin(cls, recording, kwargs): return margin @classmethod - def unserialize_in_worker(cls, kwargs): + def unserialize_in_worker(cls, kwargs): return kwargs @classmethod def main_function(cls, traces, method_kwargs): - - peak_sign = method_kwargs['peak_sign'] - abs_thresholds = method_kwargs['abs_thresholds'] - exclude_sweep_size = method_kwargs['exclude_sweep_size'] - neighbours_mask = method_kwargs['neighbours_mask'] - templates = method_kwargs['templates'] - nbefore = method_kwargs['nbefore'] - nafter = method_kwargs['nafter'] - margin = method_kwargs['margin'] - + peak_sign = method_kwargs["peak_sign"] + abs_thresholds = method_kwargs["abs_thresholds"] + exclude_sweep_size = method_kwargs["exclude_sweep_size"] + neighbours_mask = method_kwargs["neighbours_mask"] + templates = method_kwargs["templates"] + nbefore = method_kwargs["nbefore"] + nafter = method_kwargs["nafter"] + margin = method_kwargs["margin"] + if margin > 0: peak_traces = traces[margin:-margin, :] else: peak_traces = traces - peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks(peak_traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask) + peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( + peak_traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask + ) peak_sample_ind += margin spikes = np.zeros(peak_sample_ind.size, dtype=spike_dtype) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 4f3714f458..163ceca2a2 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -89,7 +89,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): channel_distance = get_channel_distances(recording) d['neighbours_mask'] = channel_distance < d['local_radius_um'] - + extremum_channel = get_template_extremum_channel(we, peak_sign=d['peak_sign'], outputs='index') # as numpy vector extremum_channel = np.array([extremum_channel[unit_id] for unit_id in unit_ids], dtype="int64") @@ -221,17 +221,17 @@ def _tdc_find_spikes(traces, d, level=0): ## pure numpy with cluster+channel spasity # union_channels, = np.nonzero(np.any(d['sparsity_mask'][possible_clusters, :], axis=0)) # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) - + ## numba with cluster+channel spasity union_channels = np.any(d['sparsity_mask'][possible_clusters, :], axis=0) # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) distances = numba_sparse_dist(wf_short, templates_short, union_channels, possible_clusters) - - + + # DEBUG #~ ind = np.argmin(distances) #~ cluster_index = possible_clusters[ind] - + for ind in np.argsort(distances)[:d['num_template_try']]: cluster_index = possible_clusters[ind] diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 620d9c6461..9e7abbb352 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -335,9 +335,9 @@ def initialize_and_check_kwargs(cls, recording, kwargs): """ d = cls.default_params.copy() d.update(kwargs) - parameters = d.get('parameters', {}) - templates 
= d['templates'] - templates = templates.astype(np.float32, casting='safe') + parameters = d.get("parameters", {}) + templates = d["templates"] + templates = templates.astype(np.float32, casting="safe") # Aggregate useful parameters/variables for handy access in downstream functions params = WobbleParameters(**parameters) @@ -362,10 +362,10 @@ def initialize_and_check_kwargs(cls, recording, kwargs): ) # Pack initial data into kwargs - d['params'] = params - d['template_meta'] = template_meta - d['sparsity'] = sparsity - d['template_data'] = template_data + d["params"] = params + d["template_meta"] = template_meta + d["sparsity"] = sparsity + d["template_data"] = template_data return d @classmethod @@ -373,7 +373,7 @@ def serialize_method_kwargs(cls, kwargs): # This function does nothing without a waveform extractor -- candidate for refactor kwargs = dict(kwargs) # remove waveform_extractor - kwargs.pop('waveform_extractor') + kwargs.pop("waveform_extractor") return kwargs @classmethod From 6d281d0d565795627bafe1c70b553bd5d621ee0b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 16 May 2023 13:51:54 +0200 Subject: [PATCH 07/17] fix tdc --- .../sortingcomponents/matching/tdc.py | 79 +++++++++---------- 1 file changed, 36 insertions(+), 43 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 163ceca2a2..fc02237581 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -1,7 +1,11 @@ import numpy as np import scipy -from spikeinterface.core import (WaveformExtractor, get_noise_levels, get_channel_distances, - get_template_extremum_channel) +from spikeinterface.core import ( + WaveformExtractor, + get_noise_levels, + get_channel_distances, + get_template_extremum_channel, +) from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive @@ -40,18 +44,18 @@ class TridesclousPeeler(BaseTemplateMatchingEngine): """ default_params = { - 'peak_sign': 'neg', - 'peak_shift_ms': 0.2, - 'detect_threshold': 5, - 'noise_levels': None, - 'local_radius_um': 75, - 'random_chunk_kwargs': {}, - 'num_closest' : 5, - 'sample_shift': 3, - 'ms_before': 0.8, - 'ms_after': 1.2, - 'num_peeler_loop': 2, - 'num_template_try' : 1 + "peak_sign": "neg", + "peak_shift_ms": 0.2, + "detect_threshold": 5, + "noise_levels": None, + "local_radius_um": 75, + "random_chunk_kwargs": {}, + "num_closest": 5, + "sample_shift": 3, + "ms_before": 0.8, + "ms_after": 1.2, + "num_peeler_loop": 2, + "num_template_try": 1, } @classmethod @@ -61,15 +65,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls.default_params.copy() d.update(kwargs) - we = d['waveform_extractor'] - templates = d['templates'] + we = d["waveform_extractor"] + templates = d["templates"] unit_ids = we.unit_ids channel_ids = we.channel_ids sr = we.sampling_frequency - nbefore_short = int(d['ms_before'] * sr / 1000.) - nafter_short = int(d['ms_before'] * sr / 1000.) 
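        # For scale (a 30 kHz sampling rate is assumed for illustration): the
        # default ms_before=0.8 gives nbefore_short = int(0.8 * 30000 / 1000.0)
        # = 24 samples, so the shortened templates compared against waveforms
        # span only a few dozen samples around the peak instead of the full
        # waveform window.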
+ nbefore_short = int(d["ms_before"] * sr / 1000.0) + nafter_short = int(d["ms_before"] * sr / 1000.0) assert nbefore_short <= we.nbefore assert nafter_short <= we.nafter d["nbefore_short"] = nbefore_short @@ -81,16 +85,18 @@ def initialize_and_check_kwargs(cls, recording, kwargs): templates_short = templates[:, slice(s0, s1), :].copy() d["templates_short"] = templates_short - d['peak_shift'] = int(d['peak_shift_ms'] / 1000 * sr) + d["peak_shift"] = int(d["peak_shift_ms"] / 1000 * sr) - if d['noise_levels'] is None: - print('TridesclousPeeler : noise should be computed outside') - d['noise_levels'] = get_noise_levels(recording, **d['random_chunk_kwargs']) + if d["noise_levels"] is None: + print("TridesclousPeeler : noise should be computed outside") + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"]) + + d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] channel_distance = get_channel_distances(recording) - d['neighbours_mask'] = channel_distance < d['local_radius_um'] + d["neighbours_mask"] = channel_distance < d["local_radius_um"] - extremum_channel = get_template_extremum_channel(we, peak_sign=d['peak_sign'], outputs='index') + extremum_channel = get_template_extremum_channel(we, peak_sign=d["peak_sign"], outputs="index") # as numpy vector extremum_channel = np.array([extremum_channel[unit_id] for unit_id in unit_ids], dtype="int64") d["extremum_channel"] = extremum_channel @@ -113,7 +119,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): closest_u = np.array(closest_u[: d["num_closest"]]) # compute unitary discriminent vector - chans, = np.nonzero(d['sparsity_mask'][unit_ind, :]) + (chans,) = np.nonzero(d["sparsity_mask"][unit_ind, :]) template_sparse = templates[unit_ind, :, :][:, chans] closest_vec = [] # against N closets @@ -218,32 +224,19 @@ def _tdc_find_spikes(traces, d, level=0): # ~ wf = traces[s0:s1, :] - ## pure numpy with cluster+channel spasity - # union_channels, = np.nonzero(np.any(d['sparsity_mask'][possible_clusters, :], axis=0)) - # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) - - ## numba with cluster+channel spasity - union_channels = np.any(d['sparsity_mask'][possible_clusters, :], axis=0) - # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) - distances = numba_sparse_dist(wf_short, templates_short, union_channels, possible_clusters) - - - # DEBUG - #~ ind = np.argmin(distances) - #~ cluster_index = possible_clusters[ind] - - for ind in np.argsort(distances)[:d['num_template_try']]: - cluster_index = possible_clusters[ind] + s0 = sample_index - d["nbefore_short"] + s1 = sample_index + d["nafter_short"] + wf_short = traces[s0:s1, :] ## pure numpy with cluster spasity # distances = np.sum(np.sum((templates[possible_clusters, :, :] - wf[None, : , :])**2, axis=1), axis=1) ## pure numpy with cluster+channel spasity - # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, :], axis=0)) + # union_channels, = np.nonzero(np.any(d['sparsity_mask'][possible_clusters, :], axis=0)) # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) ## numba with cluster+channel spasity - union_channels = np.any(d["template_sparsity"][possible_clusters, :], axis=0) + union_channels = np.any(d["sparsity_mask"][possible_clusters, :], axis=0) # distances = numba_sparse_dist(wf, templates, union_channels, 
possible_clusters) distances = numba_sparse_dist(wf_short, templates_short, union_channels, possible_clusters) From 71214862d4a116549340dd8d4dfffb1bb8f526ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 May 2023 08:32:17 +0000 Subject: [PATCH 08/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- doc/modules/core.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index a1c28ecfaf..9af69768dd 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -247,7 +247,7 @@ Finally, an existing :py:class:`~spikeinterface.core.WaveformExtractor` can be s **IMPORTANT:** to load a waveform extractor object from disk, it needs to be able to reload the associated -:code:`sorting` object (the :code:`recording` is optional, using :code:`with_recording=False`). +:code:`sorting` object (the :code:`recording` is optional, using :code:`with_recording=False`). In order to make a waveform folder portable (e.g. copied to another location or machine), one can do: .. code-block:: python From d9efa43025679e46f8fafb4915a20fbf70bdf5c3 Mon Sep 17 00:00:00 2001 From: yger Date: Wed, 24 May 2023 11:54:13 +0200 Subject: [PATCH 09/17] WIP for matching objects --- .../sortingcomponents/matching/circus.py | 38 +++--- .../sortingcomponents/matching/main.py | 122 ++++++++++++------ .../sortingcomponents/matching/naive.py | 10 +- .../sortingcomponents/matching/tdc.py | 55 ++++---- 4 files changed, 132 insertions(+), 93 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index c4e1716640..8ee518a85a 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -203,7 +203,7 @@ def _prepare_templates(cls, d): num_channels = d["num_channels"] templates = d["templates"] num_templates = len(templates) - d["circus_templates"] = d["templates"].copy() + d["circus_templates"] = d["templates"].data.copy() d["norms"] = np.zeros(num_templates, dtype=np.float32) for unit_ind in range(num_templates): @@ -219,7 +219,7 @@ def _compress_templates(cls, d): d["circus_templates"] = {} for unit_ind in range(num_templates): - active_channels = d["sparsity_mask"][unit_ind] + active_channels = d["templates"].sparsity_mask[unit_ind] d["circus_templates"][unit_ind] = templates[unit_ind][:, active_channels] return d @@ -232,21 +232,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs): for v in ["omp_min_sps"]: assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" - d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() - d["num_samples"] = d["waveform_extractor"].nsamples - d["num_templates"] = len(d["templates"]) - d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() - - d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() - d["num_samples"] = d["waveform_extractor"].nsamples - d["nbefore"] = d["waveform_extractor"].nbefore - d["nafter"] = d["waveform_extractor"].nafter - d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + num_templates = d["templates"].num_templates + num_channels = d["templates"].num_channels + num_samples = d["templates"].nsamples + d["sampling_frequency"] = recording.get_sampling_frequency() d = cls._prepare_templates(d) if d["overlaps"] is None: - 
d["overlaps"] = compute_overlaps(d["circus_templates"], d["num_samples"], d["num_channels"]) + d["overlaps"] = compute_overlaps(d["circus_templates"], num_samples, num_channels) d = cls._compress_templates(d) @@ -263,24 +257,24 @@ def unserialize_in_worker(cls, kwargs): @classmethod def get_margin(cls, recording, kwargs): - margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) + templates = kwargs['templates'] + margin = 2 * max(templates.nbefore, templates.nafter) return margin @classmethod def main_function(cls, traces, d): templates = d["circus_templates"] - num_templates = d["num_templates"] - num_channels = d["num_channels"] - num_samples = d["num_samples"] + num_templates = d["templates"].num_templates + num_channels = d["templates"].num_channels + num_samples = d["templates"].nsamples overlaps = d["overlaps"] norms = d["norms"] - nbefore = d["nbefore"] - nafter = d["nafter"] + nbefore = d["templates"].nbefore + nafter =d["templates"].nafter omp_tol = np.finfo(np.float32).eps - num_samples = d["nafter"] + d["nbefore"] neighbor_window = num_samples - 1 min_amplitude, max_amplitude = d["amplitudes"] - sparsities = d["sparsity_mask"] + sparsity_mask = d["templates"].sparsity_mask ignored_ids = d["ignored_ids"] stop_criteria = d["stop_criteria"] @@ -312,7 +306,7 @@ def main_function(cls, traces, d): cached_fft_kernels.update({i: sp_fft.rfftn(kernel_filter, fshape, axes=axes)}) cached_fft_kernels["fshape"] = fshape[0] - fft_cache.update({"mask": sparsities[i], "template": cached_fft_kernels[i]}) + fft_cache.update({"mask": sparsity_mask[i], "template": cached_fft_kernels[i]}) convolution = fftconvolve_with_cache(dummy_filter, dummy_traces, fft_cache, axes=1, mode="valid") if len(convolution) > 0: diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index ab92316bfd..e8f467cfb8 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -4,13 +4,83 @@ from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs from spikeinterface.core import get_chunk_with_margin, compute_sparsity, WaveformExtractor +from threadpoolctl import threadpool_limits +import numpy as np + +from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs +from spikeinterface.core import get_chunk_with_margin + + +class TemplatesDictionary(object): + + def __init__( + self, + data, + unit_ids, + ms_before, + ms_after, + sparsity_mask=None + ): + + self.data = data.copy() + self.unit_ids = unit_ids + self.ms_before = ms_before + self.ms_after = ms_after + if sparsity_mask is None: + self.sparsity_mask = np.sum(data, axis=(1)) == 0 + else: + assert sparsity_mask.shape == (data.shape[0], data.shape[2]), 'sparsity_mask has the wrong shape' + self.sparsity_mask = sparsity_mask + + for i in range(len(self.data)): + active_channels = self.sparsity_mask[i] + self.data[i][:, ~active_channels] = 0 + + def __getitem__(self, template_id): + return self.data[template_id] + + def __getslice__(self, start, stop): + return self.data[start:stop] + + @property + def shape(self): + return self.data.shape + + @property + def num_channels(self): + return self.data.shape[2] + + @property + def nsamples(self): + return self.data.shape[1] + + def __len__(self): + return len(self.data) + + def get_template_extremum_channel(self, peak_sign='neg', outputs="index"): + assert peak_sign in ['neg', 'pos'], "peak_sign should be in ['neg', 'pos']" + return + + +def 
create_templates_from_waveform_extractor(waveform_extractor, mode='median', sparsity=None): + + if sparsity is not None and not waveform_extractor.is_sparse(): + sparsity_mask = compute_sparsity(waveform_extractor, **sparsity) + else: + sparsity_mask = None + + data = waveform_extractor.get_all_templates(mode) + unit_ids = waveform_extractor.unit_ids + ms_before = waveform_extractor.ms_before + ms_after = waveform_extractor.ms_after + return TemplatesDictionary(data, waveform_extractor.unit_ids, ms_before, ms_after, sparsity_mask) + def find_spikes_from_templates( recording, waveform_extractor, sparsity={"method": "ptp", "threshold": 1}, - templates=None, - sparsity_mask=None, + templates_dictionary=None, method="naive", method_kwargs={}, extra_outputs=False, @@ -27,10 +97,8 @@ def find_spikes_from_templates( sparsity: dict or None Parameters that should be given to sparsify the templates, if waveform_extractor is not already sparse - templates: np.array - If provided, then the templates are used instead of the ones from the waveform_extractor - sparsity_mask: np.array, bool - If provided, the sparsity mask used for the provided templates + templates_dictionary: TemplatesDictionnary + If provided, then these templates are used instead of the ones from the waveform_extractor method: str Which method to use ('naive' | 'tridesclous' | 'circus' | 'circus-omp' | 'wobble') method_kwargs: dict, optional @@ -61,7 +129,7 @@ def find_spikes_from_templates( # initialize the templates method_kwargs = method_class.initialize_and_sparsify_templates( - method_kwargs, waveform_extractor, sparsity, templates, sparsity_mask + method_kwargs, waveform_extractor, sparsity, templates ) # initialize @@ -141,37 +209,15 @@ def _find_spikes_chunk(segment_index, start_frame, end_frame, worker_ctx): # generic class for template engine class BaseTemplateMatchingEngine: @classmethod - def initialize_and_sparsify_templates(cls, kwargs, waveform_extractor, sparsity, templates, sparsity_mask): + def initialize_and_sparsify_templates(cls, kwargs, waveform_extractor, templates_dictionary, sparsity): assert isinstance(waveform_extractor, WaveformExtractor) - kwargs.update( - { - "nbefore": waveform_extractor.nbefore, - "nafter": waveform_extractor.nafter, - "sampling_frequency": waveform_extractor.sampling_frequency, - } - ) - - num_channels = waveform_extractor.get_num_channels() - - if templates is not None: - kwargs["templates"] = templates.copy() - num_templates = len(templates) - if sparsity_mask is None: - kwargs["sparsity_mask"] = np.ones((num_templates, num_channels), dtype=bool) - else: - kwargs["templates"] = waveform_extractor.get_all_templates().copy() - num_templates = len(kwargs["templates"]) - if waveform_extractor.is_sparse(): - kwargs["sparsity_mask"] = waveform_extractor.sparsity.mask - else: - if sparsity is not None: - kwargs["sparsity_mask"] = compute_sparsity(waveform_extractor, **sparsity).mask - else: - kwargs["sparsity_mask"] = np.ones((num_templates, num_channels), dtype=bool) - - for unit_ind in range(num_templates): - active_channels = kwargs["sparsity_mask"][unit_ind] - kwargs["templates"][unit_ind][:, ~active_channels] = 0 + + if templates_dictionary is not None: + templates_dictionary = create_templates_from_waveform_extractor(waveform_extractor, sparsity=sparsity) + + assert isinstance(templates_dictionary, TemplatesDictionary) + + kwargs["templates"] = templates_dictionary return kwargs @@ -185,8 +231,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): def 
serialize_method_kwargs(cls, kwargs): """This function serializes kwargs to distribute them to workers""" kwargs = dict(kwargs) - # remove waveform_extractor - kwargs.pop("waveform_extractor") return kwargs @classmethod diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 0073858200..459a1c1245 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -62,7 +62,8 @@ def initialize_and_check_kwargs(cls, recording, kwargs): @classmethod def get_margin(cls, recording, kwargs): - margin = max(kwargs["nbefore"], kwargs["nafter"]) + templates = kwargs['templates'] + margin = max(templates.nbefore, templates.nafter) return margin @classmethod @@ -76,8 +77,9 @@ def main_function(cls, traces, method_kwargs): exclude_sweep_size = method_kwargs["exclude_sweep_size"] neighbours_mask = method_kwargs["neighbours_mask"] templates = method_kwargs["templates"] - nbefore = method_kwargs["nbefore"] - nafter = method_kwargs["nafter"] + sparsity_mask = templates.sparsity_mask + nbefore = templates.nbefore + nafter = templates.nafter margin = method_kwargs["margin"] if margin > 0: @@ -99,7 +101,7 @@ def main_function(cls, traces, method_kwargs): i1 = peak_sample_ind[i] + nafter waveforms = traces[i0:i1, :] - dist = np.sum(np.sum((templates - waveforms[None, :, :]) ** 2, axis=1), axis=1) + dist = np.sum(np.sum((templates.data - waveforms[None, :, :]) ** 2, axis=1), axis=1) cluster_index = np.argmin(dist) spikes["cluster_index"][i] = cluster_index diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index fc02237581..0569b83f56 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -1,7 +1,6 @@ import numpy as np import scipy from spikeinterface.core import ( - WaveformExtractor, get_noise_levels, get_channel_distances, get_template_extremum_channel, @@ -65,24 +64,25 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls.default_params.copy() d.update(kwargs) - we = d["waveform_extractor"] templates = d["templates"] - unit_ids = we.unit_ids - channel_ids = we.channel_ids + unit_ids = templates.unit_ids + channel_ids = recording.get_channel_ids() + channel_locations = recording.get_channel_locations() + sparsity_mask = templates.sparsity_mask - sr = we.sampling_frequency + sr = recording.sampling_frequency nbefore_short = int(d["ms_before"] * sr / 1000.0) nafter_short = int(d["ms_before"] * sr / 1000.0) - assert nbefore_short <= we.nbefore - assert nafter_short <= we.nafter + assert nbefore_short <= templates.nbefore + assert nafter_short <= templates.nafter d["nbefore_short"] = nbefore_short d["nafter_short"] = nafter_short - s0 = we.nbefore - nbefore_short - s1 = -(we.nafter - nafter_short) + s0 = templates.nbefore - nbefore_short + s1 = -(templates.nafter - nafter_short) if s1 == 0: s1 = None - templates_short = templates[:, slice(s0, s1), :].copy() + templates_short = templates.data[:, slice(s0, s1), :].copy() d["templates_short"] = templates_short d["peak_shift"] = int(d["peak_shift_ms"] / 1000 * sr) @@ -96,13 +96,11 @@ def initialize_and_check_kwargs(cls, recording, kwargs): channel_distance = get_channel_distances(recording) d["neighbours_mask"] = channel_distance < d["local_radius_um"] - extremum_channel = get_template_extremum_channel(we, peak_sign=d["peak_sign"], outputs="index") + extremum_channel = 
templates.get_template_extremum_channel(peak_sign=d["peak_sign"], outputs="index") # as numpy vector extremum_channel = np.array([extremum_channel[unit_id] for unit_id in unit_ids], dtype="int64") d["extremum_channel"] = extremum_channel - channel_locations = we.recording.get_channel_locations() - # TODO try it with real locaion unit_locations = channel_locations[extremum_channel] # ~ print(unit_locations) @@ -119,12 +117,12 @@ def initialize_and_check_kwargs(cls, recording, kwargs): closest_u = np.array(closest_u[: d["num_closest"]]) # compute unitary discriminent vector - (chans,) = np.nonzero(d["sparsity_mask"][unit_ind, :]) - template_sparse = templates[unit_ind, :, :][:, chans] + (chans,) = np.nonzero(sparsity_mask[unit_ind, :]) + template_sparse = templates[unit_ind][:, chans] closest_vec = [] # against N closets for u in closest_u: - vec = templates[u, :, :][:, chans] - template_sparse + vec = templates[u][:, chans] - template_sparse vec /= np.sum(vec**2) closest_vec.append((u, vec)) # against noise @@ -155,7 +153,8 @@ def unserialize_in_worker(cls, kwargs): @classmethod def get_margin(cls, recording, kwargs): - margin = 2 * (kwargs["nbefore"] + kwargs["nafter"]) + templates = kwargs['templates'] + margin = 2 * (templates.nbefore + templates.nafter) return margin @classmethod @@ -236,7 +235,7 @@ def _tdc_find_spikes(traces, d, level=0): # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) ## numba with cluster+channel spasity - union_channels = np.any(d["sparsity_mask"][possible_clusters, :], axis=0) + union_channels = np.any(templates.sparsity_mask[possible_clusters, :], axis=0) # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) distances = numba_sparse_dist(wf_short, templates_short, union_channels, possible_clusters) @@ -247,8 +246,8 @@ def _tdc_find_spikes(traces, d, level=0): for ind in np.argsort(distances)[: d["num_template_try"]]: cluster_index = possible_clusters[ind] - chan_sparsity = d["template_sparsity"][cluster_index, :] - template_sparse = templates[cluster_index, :, :][:, chan_sparsity] + sparsity_mask = templates.sparsity_mask[cluster_index, :] + template_sparse = templates[cluster_index][:, sparsity_mask] # find best shift @@ -262,9 +261,9 @@ def _tdc_find_spikes(traces, d, level=0): ## numba version numba_best_shift( traces, - templates[cluster_index, :, :], + templates[cluster_index], sample_index, - d["nbefore"], + templates.nbefore, possible_shifts, distances_shift, chan_sparsity, @@ -273,8 +272,8 @@ def _tdc_find_spikes(traces, d, level=0): shift = possible_shifts[ind_shift] sample_index = sample_index + shift - s0 = sample_index - d["nbefore"] - s1 = sample_index + d["nafter"] + s0 = sample_index - templates.nbefore + s1 = sample_index + templates.nafter wf_sparse = traces[s0:s1, chan_sparsity] # accept or not @@ -296,9 +295,9 @@ def _tdc_find_spikes(traces, d, level=0): amplitude = 1.0 # remove template - template = templates[cluster_index, :, :] - s0 = sample_index - d["nbefore"] - s1 = sample_index + d["nafter"] + template = templates[cluster_index] + s0 = sample_index - templates.nbefore + s1 = sample_index + templates.nafter traces[s0:s1, :] -= template * amplitude else: @@ -333,7 +332,7 @@ def numba_sparse_dist(wf, templates, union_channels, possible_clusters): if union_channels[chan_ind]: for s in range(width): v = wf[s, chan_ind] - t = templates[cluster_index, s, chan_ind] + t = templates[cluster_index][s, chan_ind] sum_dist += (v - t) 
** 2 distances[i] = sum_dist return distances From f7d50187109c8a9197970172f3365a14d89c193d Mon Sep 17 00:00:00 2001 From: yger Date: Wed, 24 May 2023 15:02:17 +0200 Subject: [PATCH 10/17] Addition of a TemplatesDictionary object --- .../sortingcomponents/matching/circus.py | 174 +++++++++--------- .../sortingcomponents/matching/main.py | 111 +++++++++-- .../sortingcomponents/matching/naive.py | 3 +- .../sortingcomponents/matching/tdc.py | 11 +- .../sortingcomponents/matching/wobble.py | 11 +- .../tests/test_template_matching.py | 15 +- 6 files changed, 197 insertions(+), 128 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 8ee518a85a..d6bd1f7067 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -199,8 +199,6 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): @classmethod def _prepare_templates(cls, d): - num_samples = d["num_samples"] - num_channels = d["num_channels"] templates = d["templates"] num_templates = len(templates) d["circus_templates"] = d["templates"].data.copy() @@ -247,7 +245,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["ignored_ids"] = np.array(d["ignored_ids"]) omp_min_sps = d["omp_min_sps"] - d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * num_samples) return d @@ -408,7 +406,7 @@ def main_function(cls, traces, d): valid_indices = np.where(is_valid) num_spikes = len(valid_indices[0]) - spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] + spikes["sample_index"][:num_spikes] = valid_indices[1] + nbefore spikes["channel_index"][:num_spikes] = 0 spikes["cluster_index"][:num_spikes] = valid_indices[0] spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] @@ -478,13 +476,8 @@ class CircusPeeler(BaseTemplateMatchingEngine): @classmethod def _prepare_templates(cls, d): - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - max_amplitude = d["max_amplitude"] - min_amplitude = d["min_amplitude"] - use_sparse_matrix_threshold = d["use_sparse_matrix_threshold"] - d["circus_templates"] = d["templates"].copy() + num_templates = d["templates"].num_templates + d["circus_templates"] = d["templates"].data.copy() d["norms"] = np.zeros(num_templates, dtype=np.float32) for unit_ind in range(num_templates): @@ -497,8 +490,8 @@ def _prepare_templates(cls, d): def _compress_templates(cls, d): circus_templates = d.pop("circus_templates") num_templates = len(circus_templates) - num_samples = d["num_samples"] - num_channels = d["num_channels"] + num_samples = d["templates"].nsamples + num_channels = d["templates"].num_channels circus_templates = circus_templates.reshape(num_templates, -1) nnz = np.sum(circus_templates != 0) / (num_templates * num_samples * num_channels) @@ -506,7 +499,7 @@ def _compress_templates(cls, d): circus_templates = scipy.sparse.csr_matrix(circus_templates) d["is_dense"] = False else: - parameters["is_dense"] = True + d["is_dense"] = True d["circus_templates"] = circus_templates return d @@ -532,100 +525,98 @@ def _cost_function_mcc(cls, bounds, good, bad, delta_amplitude, alpha): ) return cost - @classmethod - def _optimize_amplitudes(cls, noise_snippets, d): - waveform_extractor = d["waveform_extractor"] - templates = d["circus_templates"] - num_templates = 
d["num_templates"] - max_amplitude = d["max_amplitude"] - min_amplitude = d["min_amplitude"] - alpha = 0.5 - norms = parameters["norms"] - all_units = list(waveform_extractor.sorting.unit_ids) - - parameters["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) - noise = templates.dot(noise_snippets) / norms[:, np.newaxis] - - all_amps = {} - for count, unit_id in enumerate(all_units): - waveform = waveform_extractor.get_waveforms(unit_id) - snippets = waveform.reshape(waveform.shape[0], -1).T - amps = templates.dot(snippets) / norms[:, np.newaxis] - good = amps[count, :].flatten() - - sub_amps = amps[np.concatenate((np.arange(count), np.arange(count + 1, num_templates))), :] - bad = sub_amps[sub_amps >= good] - bad = np.concatenate((bad, noise[count])) - cost_kwargs = [good, bad, max_amplitude - min_amplitude, alpha] - cost_bounds = [(min_amplitude, 1), (1, max_amplitude)] - res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) - parameters["amplitudes"][count] = res.x - - return d + # @classmethod + # def _optimize_amplitudes(cls, noise_snippets, d): + # waveform_extractor = d["waveform_extractor"] + # templates = d["circus_templates"] + # num_templates = d["templates"].num_templates + # max_amplitude = d["max_amplitude"] + # min_amplitude = d["min_amplitude"] + # alpha = 0.5 + # norms = d["norms"] + # all_units = list(waveform_extractor.sorting.unit_ids) + + # d["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) + # noise = templates.dot(noise_snippets) / norms[:, np.newaxis] + + # all_amps = {} + # for count, unit_id in enumerate(all_units): + # waveform = waveform_extractor.get_waveforms(unit_id) + # snippets = waveform.reshape(waveform.shape[0], -1).T + # amps = templates.dot(snippets) / norms[:, np.newaxis] + # good = amps[count, :].flatten() + + # sub_amps = amps[np.concatenate((np.arange(count), np.arange(count + 1, num_templates))), :] + # bad = sub_amps[sub_amps >= good] + # bad = np.concatenate((bad, noise[count])) + # cost_kwargs = [good, bad, max_amplitude - min_amplitude, alpha] + # cost_bounds = [(min_amplitude, 1), (1, max_amplitude)] + # res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) + # d["amplitudes"][count] = res.x + + # return d @classmethod def initialize_and_check_kwargs(cls, recording, kwargs): assert HAVE_SKLEARN, "CircusPeeler needs sklearn to work" default_parameters = cls._default_params.copy() default_parameters.update(kwargs) + d = default_parameters for v in ["use_sparse_matrix_threshold"]: assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" - d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() - d["num_samples"] = d["waveform_extractor"].nsamples - d["num_templates"] = len(d["templates"]) - - default_parameters["num_channels"] = default_parameters["waveform_extractor"].recording.get_num_channels() - default_parameters["num_samples"] = default_parameters["waveform_extractor"].nsamples - default_parameters["num_templates"] = len(default_parameters["waveform_extractor"].sorting.unit_ids) + num_channels = d["templates"].num_channels + num_samples = d["templates"].nsamples + num_templates = d["templates"].num_templates - if default_parameters["noise_levels"] is None: + if d["noise_levels"] is None: print("CircusPeeler : noise should be computed outside") - default_parameters["noise_levels"] = get_noise_levels( - recording, **default_parameters["random_chunk_kwargs"], return_scaled=False + 
d["noise_levels"] = get_noise_levels( + recording, **d["random_chunk_kwargs"], return_scaled=False ) + d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] + d = cls._prepare_templates(d) - d["overlaps"] = compute_overlaps(d["circus_templates"], d["num_samples"], d["num_channels"]) + d["overlaps"] = compute_overlaps(d["circus_templates"], num_samples, num_channels) d = cls._compress_templates(d) - default_parameters = cls._prepare_templates(default_parameters) - default_parameters = cls._prepare_overlaps(default_parameters) - - default_parameters["exclude_sweep_size"] = int( - default_parameters["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 + d["exclude_sweep_size"] = int( + d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 ) - default_parameters["nbefore"] = default_parameters["waveform_extractor"].nbefore - default_parameters["nafter"] = default_parameters["waveform_extractor"].nafter - default_parameters["patch_sizes"] = ( - default_parameters["waveform_extractor"].nsamples, - default_parameters["num_channels"], + d["patch_sizes"] = ( + d["templates"].nsamples, + d["templates"].num_channels, ) - default_parameters["sym_patch"] = default_parameters["nbefore"] == default_parameters["nafter"] - default_parameters["jitter"] = int( - default_parameters["jitter_ms"] * recording.get_sampling_frequency() / 1000.0 + d["sym_patch"] = d["templates"].nbefore == d["templates"].nafter + d["jitter"] = int( + d["jitter_ms"] * recording.get_sampling_frequency() / 1000.0 ) - num_segments = recording.get_num_segments() - if default_parameters["waveform_extractor"]._params["max_spikes_per_unit"] is None: - num_snippets = 1000 - else: - num_snippets = 2 * default_parameters["waveform_extractor"]._params["max_spikes_per_unit"] - - num_chunks = num_snippets // num_segments - noise_snippets = get_random_data_chunks( - recording, num_chunks_per_segment=num_chunks, chunk_size=default_parameters["num_samples"], seed=42 - ) - noise_snippets = ( - noise_snippets.reshape(num_chunks, default_parameters["num_samples"], default_parameters["num_channels"]) - .reshape(num_chunks, -1) - .T - ) - parameters = cls._optimize_amplitudes(noise_snippets, default_parameters) + d['amplitudes'] = np.zeros((num_templates, 2)) + d['amplitudes'][:, 0] = d["min_amplitude"] + d['amplitudes'][:, 1] = d["max_amplitude"] + + # num_segments = recording.get_num_segments() + # if d["waveform_extractor"]._params["max_spikes_per_unit"] is None: + # num_snippets = 1000 + # else: + # num_snippets = 2 * d["waveform_extractor"]._params["max_spikes_per_unit"] + + # num_chunks = num_snippets // num_segments + # noise_snippets = get_random_data_chunks( + # recording, num_chunks_per_segment=num_chunks, chunk_size=d["num_samples"], seed=42 + # ) + # noise_snippets = ( + # noise_snippets.reshape(num_chunks, d["num_samples"], d["num_channels"]) + # .reshape(num_chunks, -1) + # .T + # ) + # d = cls._optimize_amplitudes(noise_snippets, d) - return parameters + return d @classmethod def unserialize_in_worker(cls, kwargs): @@ -633,7 +624,8 @@ def unserialize_in_worker(cls, kwargs): @classmethod def get_margin(cls, recording, kwargs): - margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) + templates = kwargs["templates"] + margin = 2 * max(templates.nbefore, templates.nafter) return margin @classmethod @@ -642,14 +634,16 @@ def main_function(cls, traces, d): abs_threholds = d["abs_threholds"] exclude_sweep_size = d["exclude_sweep_size"] templates = d["circus_templates"] - num_templates = d["num_templates"] - 
num_channels = d["num_channels"] + num_templates = d["templates"].num_templates + num_channels = d["templates"].num_channels overlaps = d["overlaps"] margin = d["margin"] norms = d["norms"] jitter = d["jitter"] patch_sizes = d["patch_sizes"] - num_samples = d["nafter"] + d["nbefore"] + nbefore = d["templates"].nbefore + nafter = d["templates"].nafter + num_samples = nbefore + nafter neighbor_window = num_samples - 1 amplitudes = d["amplitudes"] sym_patch = d["sym_patch"] @@ -678,7 +672,7 @@ def main_function(cls, traces, d): peak_sample_ind += margin // 2 else: peak_sample_ind += margin // 2 - snippet_window = np.arange(-d["nbefore"], d["nafter"]) + snippet_window = np.arange(-nbefore, nafter) snippets = traces[peak_sample_ind[:, np.newaxis] + snippet_window] if num_peaks > 0: diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index e8f467cfb8..7152f5aadd 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -17,15 +17,15 @@ def __init__( self, data, unit_ids, - ms_before, - ms_after, + nbefore, + nafter, sparsity_mask=None ): - self.data = data.copy() + self.data = data.copy().astype(np.float32, casting="safe") self.unit_ids = unit_ids - self.ms_before = ms_before - self.ms_after = ms_after + self.nbefore = nbefore + self.nafter = nafter if sparsity_mask is None: self.sparsity_mask = np.sum(data, axis=(1)) == 0 else: @@ -46,6 +46,10 @@ def __getslice__(self, start, stop): def shape(self): return self.data.shape + @property + def num_templates(self): + return self.data.shape[0] + @property def num_channels(self): return self.data.shape[2] @@ -57,9 +61,90 @@ def nsamples(self): def __len__(self): return len(self.data) - def get_template_extremum_channel(self, peak_sign='neg', outputs="index"): - assert peak_sign in ['neg', 'pos'], "peak_sign should be in ['neg', 'pos']" - return + def get_amplitudes( + self, + peak_sign: str = "neg", + mode: str = "extremum" + ): + """ + Get amplitude per channel for each unit. + + Parameters + ---------- + waveform_extractor: WaveformExtractor + The waveform extractor + peak_sign: str + Sign of the template to compute best channels ('neg', 'pos', 'both') + mode: str + 'extremum': max or min + 'at_index': take value at spike index + + Returns + ------- + peak_values: dict + Dictionary with unit ids as keys and template amplitudes as values + """ + assert peak_sign in ("both", "neg", "pos") + assert mode in ("extremum", "at_index") + peak_values = {} + for unit_ind, unit_id in enumerate(self.unit_ids): + template = self.data[unit_ind] + + if mode == "extremum": + if peak_sign == "both": + values = np.max(np.abs(template), axis=0) + elif peak_sign == "neg": + values = -np.min(template, axis=0) + elif peak_sign == "pos": + values = np.max(template, axis=0) + elif mode == "at_index": + if peak_sign == "both": + values = np.abs(template[self.nbefore, :]) + elif peak_sign == "neg": + values = -template[self.before, :] + elif peak_sign == "pos": + values = template[self.before, :] + + peak_values[unit_id] = values + + return peak_values + + def get_extremum_channel( + self, + peak_sign: str = "neg", + mode: str = "extremum", + ): + + """ + Compute the channel with the extremum peak for each unit. 
+ + Parameters + ---------- + waveform_extractor: WaveformExtractor + The waveform extractor + peak_sign: str + Sign of the template to compute best channels ('neg', 'pos', 'both') + mode: str + 'extremum': max or min + 'at_index': take value at spike index + + Returns + ------- + extremum_channels: dict + Dictionary with unit ids as keys and extremum channels (id or index based on 'outputs') + as values + """ + + assert peak_sign in ("both", "neg", "pos") + assert mode in ("extremum", "at_index") + + peak_values = self.get_amplitudes(peak_sign=peak_sign, mode=mode) + extremum_channels_index = {} + for unit_id in self.unit_ids: + max_ind = np.argmax(peak_values[unit_id]) + extremum_channels_index[unit_id] = max_ind + + return extremum_channels_index def create_templates_from_waveform_extractor(waveform_extractor, mode='median', sparsity=None): @@ -69,11 +154,11 @@ def create_templates_from_waveform_extractor(waveform_extractor, mode='median', else: sparsity_mask = None - data = waveform_extractor.get_all_templates(mode) + data = waveform_extractor.get_all_templates(mode=mode) unit_ids = waveform_extractor.unit_ids - ms_before = waveform_extractor.ms_before - ms_after = waveform_extractor.ms_after - return TemplatesDictionary(data, waveform_extractor.unit_ids, ms_before, ms_after, sparsity_mask) + nbefore = waveform_extractor.nbefore + nafter = waveform_extractor.nafter + return TemplatesDictionary(data, waveform_extractor.unit_ids, nbefore, nafter, sparsity_mask) def find_spikes_from_templates( @@ -129,7 +214,7 @@ def find_spikes_from_templates( # initialize the templates method_kwargs = method_class.initialize_and_sparsify_templates( - method_kwargs, waveform_extractor, sparsity, templates + method_kwargs, waveform_extractor, sparsity, templates_dictionary ) # initialize diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 459a1c1245..c8e56696e3 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -56,7 +56,8 @@ def initialize_and_check_kwargs(cls, recording, kwargs): channel_distance = get_channel_distances(recording) d["neighbours_mask"] = channel_distance < d["local_radius_um"] - d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * d["sampling_frequency"] / 1000.0) + fs = recording.sampling_frequency + d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * fs / 1000.0) return d diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 0569b83f56..02a9eeba3b 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -3,7 +3,6 @@ from spikeinterface.core import ( get_noise_levels, get_channel_distances, - get_template_extremum_channel, ) from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive @@ -96,7 +95,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): channel_distance = get_channel_distances(recording) d["neighbours_mask"] = channel_distance < d["local_radius_um"] - extremum_channel = templates.get_template_extremum_channel(peak_sign=d["peak_sign"], outputs="index") + extremum_channel = templates.get_extremum_channel(peak_sign=d["peak_sign"]) # as numpy vector extremum_channel = np.array([extremum_channel[unit_id] for unit_id in unit_ids], dtype="int64") d["extremum_channel"] = extremum_channel @@ -266,7 +265,7 @@ def _tdc_find_spikes(traces, d, 
level=0): templates.nbefore, possible_shifts, distances_shift, - chan_sparsity, + sparsity_mask, ) ind_shift = np.argmin(distances_shift) shift = possible_shifts[ind_shift] @@ -274,7 +273,7 @@ def _tdc_find_spikes(traces, d, level=0): sample_index = sample_index + shift s0 = sample_index - templates.nbefore s1 = sample_index + templates.nafter - wf_sparse = traces[s0:s1, chan_sparsity] + wf_sparse = traces[s0:s1, sparsity_mask] # accept or not @@ -338,7 +337,7 @@ def numba_sparse_dist(wf, templates, union_channels, possible_clusters): return distances @jit(nopython=True) - def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity): + def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, distances_shift, sparsity_mask): """ numba implementation to compute several sample shift before template substraction """ @@ -348,7 +347,7 @@ def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, d shift = possible_shifts[i] sum_dist = 0.0 for chan_ind in range(num_chan): - if chan_sparsity[chan_ind]: + if sparsity_mask[chan_ind]: for s in range(width): v = traces[sample_index - nbefore + s + shift, chan_ind] t = template[s, chan_ind] diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 9e7abbb352..c944368741 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -228,7 +228,7 @@ def from_parameters_and_templates(cls, params, templates): sparsity : Sparsity Dataclass object for aggregating channel sparsity variables together. """ - visible_channels = np.ptp(templates, axis=1) > params.visibility_threshold + visible_channels = np.ptp(templates.data, axis=1) > params.visibility_threshold unit_overlap = np.sum( np.logical_and(visible_channels[:, np.newaxis, :], visible_channels[np.newaxis, :, :]), axis=2 ) @@ -337,7 +337,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d.update(kwargs) parameters = d.get("parameters", {}) templates = d["templates"] - templates = templates.astype(np.float32, casting="safe") # Aggregate useful parameters/variables for handy access in downstream functions params = WobbleParameters(**parameters) @@ -347,14 +346,14 @@ def initialize_and_check_kwargs(cls, recording, kwargs): ) # TODO: replace with spikeinterface sparsity # Perform initial computations on templates necessary for computing the objective - sparse_templates = np.where(sparsity.visible_channels[:, np.newaxis, :], templates, 0) + sparse_templates = np.where(sparsity.visible_channels[:, np.newaxis, :], templates.data, 0) temporal, singular, spatial = compress_templates(sparse_templates, params.approx_rank) temporal_jittered = upsample_and_jitter(temporal, params.jitter_factor, template_meta.num_samples) compressed_templates = (temporal, singular, spatial, temporal_jittered) pairwise_convolution = convolve_templates( compressed_templates, params.jitter_factor, params.approx_rank, template_meta.jittered_indices, sparsity ) - norm_squared = compute_template_norm(sparsity.visible_channels, templates) + norm_squared = compute_template_norm(sparsity.visible_channels, templates.data) template_data = TemplateData( compressed_templates=compressed_templates, pairwise_convolution=pairwise_convolution, @@ -372,8 +371,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): def serialize_method_kwargs(cls, kwargs): # This function does nothing without a 
waveform extractor -- candidate for refactor kwargs = dict(kwargs) - # remove waveform_extractor - kwargs.pop("waveform_extractor") return kwargs @classmethod @@ -419,7 +416,7 @@ def main_function(cls, traces, method_kwargs): Resulting spike train. """ # Unpack method_kwargs - nbefore, nafter = method_kwargs["nbefore"], method_kwargs["nafter"] + nbefore, nafter = method_kwargs['templates'].nbefore, method_kwargs['templates'].nafter template_meta = method_kwargs["template_meta"] params = method_kwargs["params"] sparsity = method_kwargs["sparsity"] diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index 035b08ba45..38aafeab33 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -48,23 +48,16 @@ def test_find_spikes_from_templates(method, waveform_extractor): waveform = waveform_extractor.get_waveforms("#0") num_waveforms, _, _ = waveform.shape assert num_waveforms != 0 - method_kwargs_all = {"waveform_extractor": waveform_extractor, "noise_levels": get_noise_levels(recording)} - method_kwargs = {} - method_kwargs["wobble"] = { - "templates": waveform_extractor.get_all_templates(), - "nbefore": waveform_extractor.nbefore, - "nafter": waveform_extractor.nafter, - } + method_kwargs_all = {"noise_levels": get_noise_levels(recording)} sampling_frequency = recording.get_sampling_frequency() result = {} - - method_kwargs_ = method_kwargs.get(method, {}) - method_kwargs_.update(method_kwargs_all) spikes = find_spikes_from_templates( - recording, method=method, method_kwargs=method_kwargs_, n_jobs=2, chunk_size=1000, progress_bar=True + recording, waveform_extractor, method=method, method_kwargs=method_kwargs_all, n_jobs=2, chunk_size=1000, progress_bar=True ) + if method == 'circus': + method_kwargs_all['waveform_extractor'] = waveform_extractor result[method] = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_frequency) From 66fa4e7f4809ff11a30f50b08f5511e2521b31a4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 May 2023 12:58:00 +0000 Subject: [PATCH 11/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/matching/circus.py | 22 +++++-------- .../sortingcomponents/matching/main.py | 33 +++++-------------- .../sortingcomponents/matching/naive.py | 2 +- .../sortingcomponents/matching/tdc.py | 2 +- .../sortingcomponents/matching/wobble.py | 2 +- .../tests/test_template_matching.py | 12 +++++-- 6 files changed, 29 insertions(+), 44 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index d6bd1f7067..7fc68a9c1c 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -255,7 +255,7 @@ def unserialize_in_worker(cls, kwargs): @classmethod def get_margin(cls, recording, kwargs): - templates = kwargs['templates'] + templates = kwargs["templates"] margin = 2 * max(templates.nbefore, templates.nafter) return margin @@ -268,7 +268,7 @@ def main_function(cls, traces, d): overlaps = d["overlaps"] norms = d["norms"] nbefore = d["templates"].nbefore - nafter =d["templates"].nafter + nafter = d["templates"].nafter omp_tol = np.finfo(np.float32).eps 
neighbor_window = num_samples - 1 min_amplitude, max_amplitude = d["amplitudes"] @@ -572,9 +572,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): if d["noise_levels"] is None: print("CircusPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels( - recording, **d["random_chunk_kwargs"], return_scaled=False - ) + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] @@ -582,22 +580,18 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["overlaps"] = compute_overlaps(d["circus_templates"], num_samples, num_channels) d = cls._compress_templates(d) - d["exclude_sweep_size"] = int( - d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 - ) + d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) d["patch_sizes"] = ( d["templates"].nsamples, d["templates"].num_channels, ) d["sym_patch"] = d["templates"].nbefore == d["templates"].nafter - d["jitter"] = int( - d["jitter_ms"] * recording.get_sampling_frequency() / 1000.0 - ) + d["jitter"] = int(d["jitter_ms"] * recording.get_sampling_frequency() / 1000.0) - d['amplitudes'] = np.zeros((num_templates, 2)) - d['amplitudes'][:, 0] = d["min_amplitude"] - d['amplitudes'][:, 1] = d["max_amplitude"] + d["amplitudes"] = np.zeros((num_templates, 2)) + d["amplitudes"][:, 0] = d["min_amplitude"] + d["amplitudes"][:, 1] = d["max_amplitude"] # num_segments = recording.get_num_segments() # if d["waveform_extractor"]._params["max_spikes_per_unit"] is None: diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 7152f5aadd..f8cfa000b4 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -12,16 +12,7 @@ class TemplatesDictionary(object): - - def __init__( - self, - data, - unit_ids, - nbefore, - nafter, - sparsity_mask=None - ): - + def __init__(self, data, unit_ids, nbefore, nafter, sparsity_mask=None): self.data = data.copy().astype(np.float32, casting="safe") self.unit_ids = unit_ids self.nbefore = nbefore @@ -29,9 +20,9 @@ def __init__( if sparsity_mask is None: self.sparsity_mask = np.sum(data, axis=(1)) == 0 else: - assert sparsity_mask.shape == (data.shape[0], data.shape[2]), 'sparsity_mask has the wrong shape' + assert sparsity_mask.shape == (data.shape[0], data.shape[2]), "sparsity_mask has the wrong shape" self.sparsity_mask = sparsity_mask - + for i in range(len(self.data)): active_channels = self.sparsity_mask[i] self.data[i][:, ~active_channels] = 0 @@ -61,11 +52,7 @@ def nsamples(self): def __len__(self): return len(self.data) - def get_amplitudes( - self, - peak_sign: str = "neg", - mode: str = "extremum" - ): + def get_amplitudes(self, peak_sign: str = "neg", mode: str = "extremum"): """ Get amplitude per channel for each unit. @@ -110,11 +97,10 @@ def get_amplitudes( return peak_values def get_extremum_channel( - self, - peak_sign: str = "neg", - mode: str = "extremum", - ): - + self, + peak_sign: str = "neg", + mode: str = "extremum", + ): """ Compute the channel with the extremum peak for each unit. 
@@ -147,8 +133,7 @@ def get_extremum_channel( return extremum_channels_index -def create_templates_from_waveform_extractor(waveform_extractor, mode='median', sparsity=None): - +def create_templates_from_waveform_extractor(waveform_extractor, mode="median", sparsity=None): if sparsity is not None and not waveform_extractor.is_sparse(): sparsity_mask = compute_sparsity(waveform_extractor, **sparsity) else: diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index c8e56696e3..7b6e683a34 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -63,7 +63,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): @classmethod def get_margin(cls, recording, kwargs): - templates = kwargs['templates'] + templates = kwargs["templates"] margin = max(templates.nbefore, templates.nafter) return margin diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 02a9eeba3b..b1a225c81e 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -152,7 +152,7 @@ def unserialize_in_worker(cls, kwargs): @classmethod def get_margin(cls, recording, kwargs): - templates = kwargs['templates'] + templates = kwargs["templates"] margin = 2 * (templates.nbefore + templates.nafter) return margin diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index c944368741..e9f3c46cde 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -416,7 +416,7 @@ def main_function(cls, traces, method_kwargs): Resulting spike train. 
""" # Unpack method_kwargs - nbefore, nafter = method_kwargs['templates'].nbefore, method_kwargs['templates'].nafter + nbefore, nafter = method_kwargs["templates"].nbefore, method_kwargs["templates"].nafter template_meta = method_kwargs["template_meta"] params = method_kwargs["params"] sparsity = method_kwargs["sparsity"] diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index 38aafeab33..1298c11ce5 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -54,10 +54,16 @@ def test_find_spikes_from_templates(method, waveform_extractor): result = {} spikes = find_spikes_from_templates( - recording, waveform_extractor, method=method, method_kwargs=method_kwargs_all, n_jobs=2, chunk_size=1000, progress_bar=True + recording, + waveform_extractor, + method=method, + method_kwargs=method_kwargs_all, + n_jobs=2, + chunk_size=1000, + progress_bar=True, ) - if method == 'circus': - method_kwargs_all['waveform_extractor'] = waveform_extractor + if method == "circus": + method_kwargs_all["waveform_extractor"] = waveform_extractor result[method] = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_frequency) From 2f17e14038d3dbf6bcce532407b51aebb9383269 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 May 2023 14:23:52 +0000 Subject: [PATCH 12/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/matching/main.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 753fce328a..91655f447e 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -34,15 +34,7 @@ class TemplatesDictionary(object): The TemplatesDictionary object """ - def __init__( - self, - data : np.array, - unit_ids : list, - nbefore : int, - nafter : int, - sparsity_mask=None - ) -> None: - + def __init__(self, data: np.array, unit_ids: list, nbefore: int, nafter: int, sparsity_mask=None) -> None: self.data = data.copy().astype(np.float32, casting="safe") self.unit_ids = unit_ids self.nbefore = nbefore @@ -85,11 +77,7 @@ def nsamples(self): def __len__(self): return len(self.data) - def get_amplitudes( - self, - peak_sign: str = "neg", - mode: str = "extremum" - ): + def get_amplitudes(self, peak_sign: str = "neg", mode: str = "extremum"): """ Get amplitude per channel for each unit. 

From 9cb1c402588e76aa18458aebfe082a3637011f81 Mon Sep 17 00:00:00 2001
From: yger
Date: Wed, 24 May 2023 16:25:48 +0200
Subject: [PATCH 13/17] Formatting

---
 .../sortingcomponents/matching/circus.py      |  2 +-
 .../sortingcomponents/matching/main.py        | 41 ++++++++++++++++++-
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py
index 7fc68a9c1c..042cebd19f 100644
--- a/src/spikeinterface/sortingcomponents/matching/circus.py
+++ b/src/spikeinterface/sortingcomponents/matching/circus.py
@@ -637,7 +637,7 @@ def main_function(cls, traces, d):
         patch_sizes = d["patch_sizes"]
         nbefore = d["templates"].nbefore
         nafter = d["templates"].nafter
-        num_samples = nbefore + nafter
+        num_samples = d["templates"].nsamples
         neighbor_window = num_samples - 1
         amplitudes = d["amplitudes"]
         sym_patch = d["sym_patch"]
diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py
index f8cfa000b4..753fce328a 100644
--- a/src/spikeinterface/sortingcomponents/matching/main.py
+++ b/src/spikeinterface/sortingcomponents/matching/main.py
@@ -12,11 +12,44 @@
 
 
 class TemplatesDictionary(object):
-    def __init__(self, data, unit_ids, nbefore, nafter, sparsity_mask=None):
+    """
+    Class to handle the templates in order to work with the matching
+    engines
+
+    Parameters
+    ----------
+    data: array
+        The template matrix, as a numpy array of size (n_templates, n_samples, n_channels)
+    unit_ids: list
+        The list of unit_ids
+    nbefore: int
+        The number of samples before the peaks of the templates
+    nafter: int
+        The number of samples after the peaks of the templates
+    sparsity_mask: None or array
+        If not None, an array of size (n_templates, n_channels) to set some sparsity mask
+    Returns
+    -------
+    templates: TemplatesDictionary
+        The TemplatesDictionary object
+    """
+
+    def __init__(
+        self,
+        data : np.array,
+        unit_ids : list,
+        nbefore : int,
+        nafter : int,
+        sparsity_mask=None
+    ) -> None:
+
         self.data = data.copy().astype(np.float32, casting="safe")
         self.unit_ids = unit_ids
         self.nbefore = nbefore
         self.nafter = nafter
+
+        assert self.nbefore + self.nafter == data.shape[1]
+
         if sparsity_mask is None:
             self.sparsity_mask = np.sum(data, axis=(1)) == 0
         else:
@@ -52,7 +85,11 @@ def nsamples(self):
     def __len__(self):
         return len(self.data)
 
-    def get_amplitudes(self, peak_sign: str = "neg", mode: str = "extremum"):
+    def get_amplitudes(
+        self,
+        peak_sign: str = "neg",
+        mode: str = "extremum"
+    ):
         """
         Get amplitude per channel for each unit. 
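
[Editor's note] The two "Simplification" commits below move template resolution out of
BaseTemplateMatchingEngine and into find_spikes_from_templates: the caller may pass
either a WaveformExtractor or an already-built TemplatesDictionary. A standalone sketch
of that dispatch pattern (build_templates stands in for
create_templates_from_waveform_extractor; the names here are illustrative):

def resolve_templates(build_templates, waveform_extractor=None, templates_dictionary=None):
    # at least one of the two template sources must be provided
    assert (waveform_extractor is not None) or (templates_dictionary is not None)
    if templates_dictionary is None:
        # fall back to building the dictionary from the waveform extractor
        templates_dictionary = build_templates(waveform_extractor)
    return templates_dictionary

For example, resolve_templates(build, templates_dictionary=my_templates) returns
my_templates untouched, while resolve_templates(build, waveform_extractor=we)
delegates to the builder, which is exactly the fall-back the diffs below introduce.
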

From 8671331842e1f18dfa46cef2d51199f44ee55843 Mon Sep 17 00:00:00 2001
From: yger
Date: Wed, 24 May 2023 16:28:37 +0200
Subject: [PATCH 14/17] Simplification

---
 .../sortingcomponents/matching/main.py        | 26 +++++++------------
 1 file changed, 9 insertions(+), 17 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py
index 753fce328a..a90d8ab145 100644
--- a/src/spikeinterface/sortingcomponents/matching/main.py
+++ b/src/spikeinterface/sortingcomponents/matching/main.py
@@ -172,7 +172,7 @@ def get_extremum_channel(
 
 def create_templates_from_waveform_extractor(waveform_extractor, mode="median", sparsity=None):
     if sparsity is not None and not waveform_extractor.is_sparse():
-        sparsity_mask = compute_sparsity(waveform_extractor, **sparsity)
+        sparsity_mask = compute_sparsity(waveform_extractor, **sparsity).mask
     else:
         sparsity_mask = None
 
@@ -234,10 +234,14 @@ def find_spikes_from_templates(
 
     method_class = matching_methods[method]
 
-    # initialize the templates
-    method_kwargs = method_class.initialize_and_sparsify_templates(
-        method_kwargs, waveform_extractor, sparsity, templates_dictionary
-    )
+    assert isinstance(waveform_extractor, WaveformExtractor)
+
+    if templates_dictionary is None:
+        templates_dictionary = create_templates_from_waveform_extractor(waveform_extractor, sparsity=sparsity)
+
+    assert isinstance(templates_dictionary, TemplatesDictionary)
+
+    method_kwargs["templates"] = templates_dictionary
 
     # initialize
     method_kwargs = method_class.initialize_and_check_kwargs(recording, method_kwargs)
@@ -315,18 +319,6 @@ def _find_spikes_chunk(segment_index, start_frame, end_frame, worker_ctx):
 
 # generic class for template engine
 class BaseTemplateMatchingEngine:
-    @classmethod
-    def initialize_and_sparsify_templates(cls, kwargs, waveform_extractor, templates_dictionary, sparsity):
-        assert isinstance(waveform_extractor, WaveformExtractor)
-
-        if templates_dictionary is not None:
-            templates_dictionary = create_templates_from_waveform_extractor(waveform_extractor, sparsity=sparsity)
-
-        assert isinstance(templates_dictionary, TemplatesDictionary)
-
-        kwargs["templates"] = templates_dictionary
-
-        return kwargs
 
     @classmethod
     def initialize_and_check_kwargs(cls, recording, kwargs):

From 4e196b969276d45349da06644042c2c2e2b068fd Mon Sep 17 00:00:00 2001
From: yger
Date: Wed, 24 May 2023 16:30:43 +0200
Subject: [PATCH 15/17] Simplification

---
 src/spikeinterface/sortingcomponents/matching/main.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py
index 157d212395..32f86eb092 100644
--- a/src/spikeinterface/sortingcomponents/matching/main.py
+++ b/src/spikeinterface/sortingcomponents/matching/main.py
@@ -222,9 +222,10 @@ def find_spikes_from_templates(
 
     method_class = matching_methods[method]
 
-    assert isinstance(waveform_extractor, WaveformExtractor)
+    assert (waveform_extractor is not None) or (templates_dictionary is not None)
 
     if templates_dictionary is None:
+        assert isinstance(waveform_extractor, WaveformExtractor)
         templates_dictionary = create_templates_from_waveform_extractor(waveform_extractor, sparsity=sparsity)
 
     assert isinstance(templates_dictionary, TemplatesDictionary)
 

From fee02b0250014bf0e4e14083f34e50a3f902aacb Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 24 May 2023 14:28:31 +0000
Subject: [PATCH 16/17] 
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/matching/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 32f86eb092..2d1ef85cfa 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -308,7 +308,6 @@ def _find_spikes_chunk(segment_index, start_frame, end_frame, worker_ctx): # generic class for template engine class BaseTemplateMatchingEngine: - @classmethod def initialize_and_check_kwargs(cls, recording, kwargs): """This function runs before loops""" From ef35672cdfba1f339831fe40651fe7e98e907cfc Mon Sep 17 00:00:00 2001 From: yger Date: Thu, 25 May 2023 10:42:21 +0200 Subject: [PATCH 17/17] Forgot noise levels for omp --- src/spikeinterface/sortingcomponents/matching/circus.py | 4 ++++ src/spikeinterface/sortingcomponents/matching/main.py | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 042cebd19f..f9c6660331 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -235,6 +235,10 @@ def initialize_and_check_kwargs(cls, recording, kwargs): num_samples = d["templates"].nsamples d["sampling_frequency"] = recording.get_sampling_frequency() + if d["noise_levels"] is None: + print("CircusPeeler : noise should be computed outside") + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) + d = cls._prepare_templates(d) if d["overlaps"] is None: diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 2d1ef85cfa..d0336621b4 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -49,8 +49,7 @@ def __init__(self, data: np.array, unit_ids: list, nbefore: int, nafter: int, sp self.sparsity_mask = sparsity_mask for i in range(len(self.data)): - active_channels = self.sparsity_mask[i] - self.data[i][:, ~active_channels] = 0 + self.data[i][:, ~self.sparsity_mask[i]] = 0 def __getitem__(self, template_id): return self.data[template_id]
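
[Editor's note] The last commit makes CircusPeeler fall back to computing noise levels
itself when none are passed in. For context, spikeinterface's get_noise_levels returns
a per-channel noise estimate (a robust, MAD-based one), which the peeler then scales
into per-channel detection thresholds. A self-contained sketch of that quantity on
synthetic traces (the helper below is illustrative, not the library function):

import numpy as np

def mad_noise_levels(traces):
    # traces: (num_samples, num_channels); median absolute deviation per channel,
    # scaled so the estimate matches the standard deviation for Gaussian noise
    med = np.median(traces, axis=0)
    return np.median(np.abs(traces - med), axis=0) / 0.6744897501960817

rng = np.random.default_rng(0)
traces = rng.normal(0, 5.0, (20000, 4))
noise_levels = mad_noise_levels(traces)            # close to 5.0 on every channel
detect_threshold = 5
abs_thresholds = noise_levels * detect_threshold   # mirrors d["abs_threholds"] above
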