diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py
index 3ebeafcfec..5a0148c5c4 100644
--- a/src/spikeinterface/postprocessing/amplitude_scalings.py
+++ b/src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -7,6 +7,9 @@
 from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension
 
 
+# DEBUG = True
+
+
 class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension):
     """
     Computes amplitude scalings from WaveformExtractor.
@@ -21,9 +24,25 @@ def __init__(self, waveform_extractor):
         self.spikes = self.waveform_extractor.sorting.to_spike_vector(
             extremum_channel_inds=extremum_channel_inds, use_cache=False
         )
-
-    def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after):
-        params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after)
+        self.collisions = None
+
+    def _set_params(
+        self,
+        sparsity,
+        max_dense_channels,
+        ms_before,
+        ms_after,
+        handle_collisions,
+        delta_collision_ms,
+    ):
+        params = dict(
+            sparsity=sparsity,
+            max_dense_channels=max_dense_channels,
+            ms_before=ms_before,
+            ms_after=ms_after,
+            handle_collisions=handle_collisions,
+            delta_collision_ms=delta_collision_ms,
+        )
 
         return params
 
     def _select_extension_data(self, unit_ids):
@@ -43,6 +62,11 @@ def _run(self, **job_kwargs):
         ms_before = self._params["ms_before"]
         ms_after = self._params["ms_after"]
 
+        # collisions
+        handle_collisions = self._params["handle_collisions"]
+        delta_collision_ms = self._params["delta_collision_ms"]
+        delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency)
+
         return_scaled = we._params["return_scaled"]
         unit_ids = we.unit_ids
 
@@ -67,6 +91,8 @@ def _run(self, **job_kwargs):
             assert recording.get_num_channels() <= self._params["max_dense_channels"], ""
             sparsity = ChannelSparsity.create_dense(we)
         sparsity_inds = sparsity.unit_id_to_channel_indices
+
+        # easier to use in chunk function as spikes use unit_index instead of id
         unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)}
         all_templates = we.get_all_templates()
 
@@ -93,6 +119,8 @@ def _run(self, **job_kwargs):
             cut_out_before,
             cut_out_after,
             return_scaled,
+            handle_collisions,
+            delta_collision_samples,
         )
         processor = ChunkRecordingExecutor(
             recording,
@@ -104,10 +132,18 @@ def _run(self, **job_kwargs):
             **job_kwargs,
         )
         out = processor.run()
-        (amp_scalings,) = zip(*out)
+        (amp_scalings, collisions) = zip(*out)
         amp_scalings = np.concatenate(amp_scalings)
 
-        self._extension_data[f"amplitude_scalings"] = amp_scalings
+        collisions_dict = {}
+        if handle_collisions:
+            for collision in collisions:
+                collisions_dict.update(collision)
+            self.collisions = collisions_dict
+            # Note: collisions are not in _extension_data because they are not picklable. We only store the indices.
+            self._extension_data["collisions"] = np.array(list(collisions_dict.keys()))
+
+        self._extension_data["amplitude_scalings"] = amp_scalings
 
     def get_data(self, outputs="concatenated"):
         """
@@ -154,6 +190,8 @@ def compute_amplitude_scalings(
     max_dense_channels=16,
     ms_before=None,
     ms_after=None,
+    handle_collisions=True,
+    delta_collision_ms=2,
     load_if_exists=False,
     outputs="concatenated",
     **job_kwargs,
@@ -165,22 +203,27 @@ def compute_amplitude_scalings(
     ----------
     waveform_extractor: WaveformExtractor
         The waveform extractor object
-    sparsity: ChannelSparsity
+    sparsity: ChannelSparsity, default: None
        If waveforms are not sparse, sparsity is required if the number of channels is greater than
        `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used.
-        By default None
     max_dense_channels: int, default: 16
         Maximum number of channels to allow running without sparsity. To compute amplitude scaling using dense
         waveforms, set this to None, sparsity to None, and pass dense waveforms as input.
-    ms_before : float, optional
+    ms_before : float, default: None
         The cut out to apply before the spike peak to extract local waveforms.
-        If None, the WaveformExtractor ms_before is used, by default None
-    ms_after : float, optional
+        If None, the WaveformExtractor ms_before is used.
+    ms_after : float, default: None
         The cut out to apply after the spike peak to extract local waveforms.
-        If None, the WaveformExtractor ms_after is used, by default None
+        If None, the WaveformExtractor ms_after is used.
+    handle_collisions: bool, default: True
+        Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes
+        (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a
+        multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently.
+    delta_collision_ms: float, default: 2
+        The maximum time difference in ms before and after a spike to gather colliding spikes.
     load_if_exists : bool, default: False
         Whether to load precomputed spike amplitudes, if they already exist.
-    outputs: str
+    outputs: str, default: 'concatenated'
         How the output should be returned:
             - 'concatenated'
             - 'by_unit'
@@ -197,7 +240,14 @@ def compute_amplitude_scalings(
         sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name)
     else:
         sac = AmplitudeScalingsCalculator(waveform_extractor)
-        sac.set_params(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after)
+        sac.set_params(
+            sparsity=sparsity,
+            max_dense_channels=max_dense_channels,
+            ms_before=ms_before,
+            ms_after=ms_after,
+            handle_collisions=handle_collisions,
+            delta_collision_ms=delta_collision_ms,
+        )
         sac.run(**job_kwargs)
 
     amps = sac.get_data(outputs=outputs)
@@ -218,6 +268,8 @@ def _init_worker_amplitude_scalings(
     cut_out_before,
     cut_out_after,
     return_scaled,
+    handle_collisions,
+    delta_collision_samples,
 ):
     # create a local dict per worker
     worker_ctx = {}
@@ -229,14 +281,24 @@ def _init_worker_amplitude_scalings(
     worker_ctx["nafter"] = nafter
     worker_ctx["cut_out_before"] = cut_out_before
     worker_ctx["cut_out_after"] = cut_out_after
-    worker_ctx["margin"] = max(nbefore, nafter)
     worker_ctx["return_scaled"] = return_scaled
     worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices
+    worker_ctx["handle_collisions"] = handle_collisions
+    worker_ctx["delta_collision_samples"] = delta_collision_samples
+
+    if not handle_collisions:
+        worker_ctx["margin"] = max(nbefore, nafter)
+    else:
+        # in this case we extend the margin to be able to fit collisions with spikes outside the chunk
+        margin_waveforms = max(nbefore, nafter)
+        max_margin_collisions = delta_collision_samples + margin_waveforms
+        worker_ctx["margin"] = max_margin_collisions
 
     return worker_ctx
 
 
 def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx):
+    # from sklearn.linear_model import LinearRegression
     from scipy.stats import linregress
 
     # recover variables of the worker
@@ -250,16 +312,14 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx):
     cut_out_after = worker_ctx["cut_out_after"]
     margin = worker_ctx["margin"]
     return_scaled = worker_ctx["return_scaled"]
+    handle_collisions = worker_ctx["handle_collisions"]
+    delta_collision_samples = worker_ctx["delta_collision_samples"]
 
     spikes_in_segment = spikes[segment_slices[segment_index]]
 
     i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame)
     i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame)
 
-    local_waveforms = []
-    templates = []
-    scalings = []
-
     if i0 != i1:
         local_spikes = spikes_in_segment[i0:i1]
         traces_with_margin, left, right = get_chunk_with_margin(
@@ -272,8 +332,26 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx):
             offsets = recording.get_property("offset_to_uV")
             traces_with_margin = traces_with_margin.astype("float32") * gains + offsets
 
-        # get all waveforms
-        for spike in local_spikes:
+        # set colliding spikes apart (if needed)
+        if handle_collisions:
+            # local spikes with margin!
+            i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left)
+            i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right)
+            local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin]
+            collisions_local = find_collisions(
+                local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices
+            )
+        else:
+            collisions_local = {}
+
+        # compute the scaling for each spike
+        scalings = np.zeros(len(local_spikes), dtype=float)
+        # collisions_global maps local spike indices to global spike indices
+        collisions_global = {}
+        for spike_index, spike in enumerate(local_spikes):
+            if spike_index in collisions_local.keys():
+                # we deal with overlapping spikes later
+                continue
             unit_index = spike["unit_index"]
             sample_index = spike["sample_index"]
             sparse_indices = unit_inds_to_channel_indices[unit_index]
@@ -291,10 +369,335 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx):
             else:
                 local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices]
             assert template.shape == local_waveform.shape
-            local_waveforms.append(local_waveform)
-            templates.append(template)
+
+            # here we use linregress, which is equivalent to using sklearn LinearRegression with fit_intercept=True
+            # y = local_waveform.flatten()
+            # X = template.flatten()[:, np.newaxis]
+            # reg = LinearRegression(positive=True, fit_intercept=True).fit(X, y)
+            # scalings[spike_index] = reg.coef_[0]
             linregress_res = linregress(template.flatten(), local_waveform.flatten())
-            scalings.append(linregress_res[0])
-    scalings = np.array(scalings)
+            scalings[spike_index] = linregress_res[0]
+
+        # deal with collisions
+        if len(collisions_local) > 0:
+            num_spikes_in_previous_segments = int(
+                np.sum([len(spikes[segment_slices[s]]) for s in range(segment_index)])
+            )
+            for spike_index, collision in collisions_local.items():
+                scaled_amps = fit_collision(
+                    collision,
+                    traces_with_margin,
+                    start_frame,
+                    end_frame,
+                    left,
+                    right,
+                    nbefore,
+                    all_templates,
+                    unit_inds_to_channel_indices,
+                    cut_out_before,
+                    cut_out_after,
+                )
+                # the scaling for the current spike is at index 0
+                scalings[spike_index] = scaled_amps[0]
+
+                # make collisions_global indices "absolute" by adding i0 and the cumulative number of spikes in previous segments
+                collisions_global.update({spike_index + i0 + num_spikes_in_previous_segments: collision})
+    else:
+        scalings = np.array([])
+        collisions_global = {}
+
+    return (scalings, collisions_global)
+
+
+### Collision handling ###
+def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j):
+    """
+    Returns True if the unit indices i and j are overlapping, False otherwise
+
+    Parameters
+    ----------
+    unit_inds_to_channel_indices: dict
+        A dictionary mapping unit indices to channel indices
+    i: int
+        The first unit index
+    j: int
+        The second unit index
+
+    Returns
+    -------
+    bool
+        True if the unit indices i and j are overlapping, False otherwise
+    """
+    if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0:
+        return True
+    else:
+        return False
+
-    return (scalings,)
+def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices):
+    """
+    Finds the collisions between spikes.
+
+    Parameters
+    ----------
+    spikes: np.array
+        An array of spikes
+    spikes_w_margin: np.array
+        An array of spikes within the added margin
+    delta_collision_samples: int
+        The maximum number of samples between two spikes to consider them as overlapping
+    unit_inds_to_channel_indices: dict
+        A dictionary mapping unit indices to channel indices
+
+    Returns
+    -------
+    collision_spikes_dict: dict
+        A dictionary with collisions. The key is the index of the spike with collision, the value is an
+        array of overlapping spikes, including the spike itself at position 0.
+    """
+    # TODO: refactor to speed-up
+    collision_spikes_dict = {}
+    for spike_index, spike in enumerate(spikes):
+        # find the index of the spike within the spikes_w_margin
+        spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0]
+
+        # find the possible spikes pre and post within delta_collision_samples
+        consecutive_window_pre = np.searchsorted(
+            spikes_w_margin["sample_index"],
+            spike["sample_index"] - delta_collision_samples,
+        )
+        consecutive_window_post = np.searchsorted(
+            spikes_w_margin["sample_index"],
+            spike["sample_index"] + delta_collision_samples,
+        )
+        # exclude the spike itself (it is included in the collision_spikes by construction)
+        pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin)
+        post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post)
+        possible_overlapping_spike_indices = np.concatenate(
+            (pre_possible_consecutive_spike_indices, post_possible_consecutive_spike_indices)
+        )
+
+        # find the overlapping spikes in space as well
+        for possible_overlapping_spike_index in possible_overlapping_spike_indices:
+            if _are_unit_indices_overlapping(
+                unit_inds_to_channel_indices,
+                spike["unit_index"],
+                spikes_w_margin[possible_overlapping_spike_index]["unit_index"],
+            ):
+                if spike_index not in collision_spikes_dict:
+                    collision_spikes_dict[spike_index] = np.array([spike])
+                collision_spikes_dict[spike_index] = np.concatenate(
+                    (collision_spikes_dict[spike_index], [spikes_w_margin[possible_overlapping_spike_index]])
+                )
+    return collision_spikes_dict
+
+
+def fit_collision(
+    collision,
+    traces_with_margin,
+    start_frame,
+    end_frame,
+    left,
+    right,
+    nbefore,
+    all_templates,
+    unit_inds_to_channel_indices,
+    cut_out_before,
+    cut_out_after,
+):
+    """
+    Compute the best fit for a collision between a spike and its overlapping spikes.
+    The function first cuts out the traces around the spike and its overlapping spikes, then
+    fits a multi-linear regression model to the traces using the centered templates as predictors.
+
+    Parameters
+    ----------
+    collision: np.ndarray
+        A numpy array of shape (n_colliding_spikes, ) containing the colliding spikes (spike_dtype).
+    traces_with_margin: np.ndarray
+        A numpy array of shape (n_samples, n_channels) containing the traces with a margin.
+    start_frame: int
+        The start frame of the chunk for traces_with_margin.
+    end_frame: int
+        The end frame of the chunk for traces_with_margin.
+    left: int
+        The left margin of the chunk for traces_with_margin.
+    right: int
+        The right margin of the chunk for traces_with_margin.
+    nbefore: int
+        The number of samples before the spike to consider for the fit.
+    all_templates: np.ndarray
+        A numpy array of shape (n_units, n_samples, n_channels) containing the templates.
+    unit_inds_to_channel_indices: dict
+        A dictionary mapping unit indices to channel indices.
+    cut_out_before: int
+        The number of samples to cut out before the spike.
+    cut_out_after: int
+        The number of samples to cut out after the spike.
+
+    Returns
+    -------
+    np.ndarray
+        The fitted scaling factors for the colliding spikes.
+    """
+    from sklearn.linear_model import LinearRegression
+
+    # get the first and last sample of the collision, relative to the traces with margin
+    sample_first_centered = np.min(collision["sample_index"]) - (start_frame - left)
+    sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left)
+
+    # construct sparsity as union between units' sparsity
+    sparse_indices = np.array([], dtype="int")
+    for spike in collision:
+        sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]]
+        sparse_indices = np.union1d(sparse_indices, sparse_indices_i)
+
+    local_waveform_start = max(0, sample_first_centered - cut_out_before)
+    local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after)
+    local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices]
+
+    y = local_waveform.T.flatten()
+    X = np.zeros((len(y), len(collision)))
+    for i, spike in enumerate(collision):
+        full_template = np.zeros_like(local_waveform)
+        # center wrt cutout traces
+        sample_centered = spike["sample_index"] - (start_frame - left) - local_waveform_start
+        template = all_templates[spike["unit_index"]][:, sparse_indices]
+        template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after]
+        # deal with borders
+        if sample_centered - cut_out_before < 0:
+            full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :]
+        elif sample_centered + cut_out_after > full_template.shape[0]:
+            full_template[sample_centered - cut_out_before :] = template_cut[
+                : full_template.shape[0] - (sample_centered - cut_out_before)
+            ]
+        else:
+            full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut
+        X[:, i] = full_template.T.flatten()
+
+    reg = LinearRegression(fit_intercept=True, positive=True).fit(X, y)
+    scalings = reg.coef_
+    return scalings
+
+
+# uncomment for debugging
+# def plot_collisions(we, sparsity=None, num_collisions=None):
+#     """
+#     Plot the fitting of collision spikes.
+
+#     Parameters
+#     ----------
+#     we : WaveformExtractor
+#         The WaveformExtractor object.
+#     sparsity : ChannelSparsity, default=None
+#         The ChannelSparsity. If None, only main channels are plotted.
+#     num_collisions : int, default=None
+#         Number of collisions to plot. If None, all collisions are plotted.
+#     """
+#     assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!"
+#     sac = we.load_extension("amplitude_scalings")
+#     handle_collisions = sac._params["handle_collisions"]
+#     assert handle_collisions, "Amplitude scalings was run without handling collisions!"
+#     scalings = sac.get_data()
+
+#     # overlapping_mask = sac.overlapping_mask
+#     # num_collisions = num_collisions or len(overlapping_mask)
+#     spikes = sac.spikes
+#     collisions = sac.collisions
+#     collision_keys = list(collisions.keys())
+#     num_collisions = num_collisions or len(collisions)
+#     num_collisions = min(num_collisions, len(collisions))
+
+#     for i in range(num_collisions):
+#         overlapping_spikes = collisions[collision_keys[i]]
+#         ax = plot_one_collision(
+#             we, collision_keys[i], overlapping_spikes, spikes, scalings=scalings, sparsity=sparsity
+#         )
+
+
+# def plot_one_collision(
+#     we,
+#     spike_index,
+#     overlapping_spikes,
+#     spikes,
+#     scalings=None,
+#     sparsity=None,
+#     cut_out_samples=100,
+#     ax=None,
+# ):
+#     import matplotlib.pyplot as plt
+
+#     if ax is None:
+#         fig, ax = plt.subplots()
+
+#     recording = we.recording
+#     nbefore_nafter_max = max(we.nafter, we.nbefore)
+#     cut_out_samples = max(cut_out_samples, nbefore_nafter_max)
+
+#     if sparsity is not None:
+#         unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices
+#         sparse_indices = np.array([], dtype="int")
+#         for spike in overlapping_spikes:
+#             sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]]
+#             sparse_indices = np.union1d(sparse_indices, sparse_indices_i)
+#     else:
+#         sparse_indices = np.unique(overlapping_spikes["channel_index"])
+
+#     channel_ids = recording.channel_ids[sparse_indices]
+
+#     center_spike = overlapping_spikes[0]
+#     max_delta = np.max(
+#         [
+#             np.abs(center_spike["sample_index"] - np.min(overlapping_spikes[1:]["sample_index"])),
+#             np.abs(center_spike["sample_index"] - np.max(overlapping_spikes[1:]["sample_index"])),
+#         ]
+#     )
+#     sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples)
+#     ef = min(
+#         center_spike["sample_index"] + max_delta + cut_out_samples,
+#         recording.get_num_samples(segment_index=center_spike["segment_index"]),
+#     )
+#     tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True)
+#     ts = np.arange(sf, ef) / recording.sampling_frequency * 1000
+#     max_tr = np.max(np.abs(tr_overlap))
+
+#     for ch, tr in enumerate(tr_overlap.T):
+#         _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k")
+#         ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}")
+
+#     used_labels = []
+#     for i, spike in enumerate(overlapping_spikes):
+#         label = f"U{spike['unit_index']}"
+#         if label in used_labels:
+#             label = None
+#         else:
+#             used_labels.append(label)
+#         ax.axvline(
+#             spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label
+#         )
+
+#     if scalings is not None:
+#         fitted_traces = np.zeros_like(tr_overlap)
+
+#         all_templates = we.get_all_templates()
+#         for i, spike in enumerate(overlapping_spikes):
+#             template = all_templates[spike["unit_index"]]
+#             overlap_index = np.where(spikes == spike)[0][0]
+#             template_scaled = scalings[overlap_index] * template
+#             template_scaled_sparse = template_scaled[:, sparse_indices]
+#             sample_start = spike["sample_index"] - we.nbefore
+#             sample_end = sample_start + template_scaled_sparse.shape[0]
+
+#             fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse
+
+#             for ch, temp in enumerate(template_scaled_sparse.T):
+#                 ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000
+#                 _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--")
+
+#         for ch, tr in enumerate(fitted_traces.T):
+#             _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7)
+#         fitted_line = ax.get_lines()[-1]
+#         fitted_line.set_label("Fitted")
+
+#     ax.legend()
+#     ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}")
+#     return ax
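For context, a minimal usage sketch of the parameters this patch adds. It assumes a WaveformExtractor `we` has already been computed; the keyword values shown are just the new defaults:

```python
from spikeinterface.postprocessing import compute_amplitude_scalings

# `we` is a previously computed WaveformExtractor (placeholder for this sketch)
scalings = compute_amplitude_scalings(
    we,
    handle_collisions=True,  # jointly fit spikes that are close in time and overlap in space
    delta_collision_ms=2,    # window (ms) used to gather colliding spikes
    outputs="concatenated",
)
```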
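The inline comment in `_amplitude_scalings_chunk` claims that `scipy.stats.linregress` is equivalent to sklearn's `LinearRegression` with `fit_intercept=True` for the single-spike fit. A quick self-contained check of that claim on synthetic data (`positive=True` is omitted here, since it only matters when the unconstrained slope would be negative):

```python
import numpy as np
from scipy.stats import linregress
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(42)
template = rng.normal(size=200)  # stands in for a flattened template
waveform = 1.5 * template + 0.1 + rng.normal(0, 0.05, size=200)  # scaled + offset + noise

slope = linregress(template, waveform).slope
coef = LinearRegression(fit_intercept=True).fit(template[:, np.newaxis], waveform).coef_[0]
assert np.isclose(slope, coef)  # both estimate the same scaling (~1.5)
```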
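And a toy illustration of the joint fit that `fit_collision` performs, under simplified assumptions (two synthetic templates, hard-coded spike times, no border handling or sparsity): each column of the design matrix is one template placed at its own spike time and flattened channel-major, and a non-negative multi-linear regression recovers the per-spike scaling factors jointly:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n_samples, n_channels, n_t = 120, 4, 40

# two synthetic templates with negative peaks (illustrative shapes only)
template_a = np.zeros((n_t, n_channels))
template_a[15:25] = -np.hanning(10)[:, None]
template_b = np.zeros((n_t, n_channels))
template_b[12:28] = -0.8 * np.hanning(16)[:, None]

# noisy trace containing two colliding spikes, 20 samples apart
trace = rng.normal(0, 0.01, size=(n_samples, n_channels))
starts, true_scalings = [30, 50], [1.2, 0.8]
for s, template, scale in zip(starts, (template_a, template_b), true_scalings):
    trace[s : s + n_t] += scale * template

# design matrix: one column per colliding spike (its template at its own time, channel-major)
X = np.zeros((trace.size, len(starts)))
for i, (s, template) in enumerate(zip(starts, (template_a, template_b))):
    placed = np.zeros_like(trace)
    placed[s : s + n_t] = template
    X[:, i] = placed.T.flatten()

y = trace.T.flatten()
reg = LinearRegression(fit_intercept=True, positive=True).fit(X, y)
print(reg.coef_)  # close to [1.2, 0.8]
```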