From f737f5a4b4b998321ff2999ec57915f1c1b6dc82 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 24 Aug 2023 15:14:55 +0200 Subject: [PATCH 1/5] wip collisions --- .../postprocessing/amplitude_scalings.py | 384 +++++++++++++++++- 1 file changed, 370 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 3ebeafcfec..7539e4d0b7 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -22,8 +22,25 @@ def __init__(self, waveform_extractor): extremum_channel_inds=extremum_channel_inds, use_cache=False ) - def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after): - params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) + def _set_params( + self, + sparsity, + max_dense_channels, + ms_before, + ms_after, + handle_collisions, + max_consecutive_collisions, + delta_collision_ms, + ): + params = dict( + sparsity=sparsity, + max_dense_channels=max_dense_channels, + ms_before=ms_before, + ms_after=ms_after, + handle_collisions=handle_collisions, + max_consecutive_collisions=max_consecutive_collisions, + delta_collision_ms=delta_collision_ms, + ) return params def _select_extension_data(self, unit_ids): @@ -43,6 +60,12 @@ def _run(self, **job_kwargs): ms_before = self._params["ms_before"] ms_after = self._params["ms_after"] + # collisions + handle_collisions = self._params["handle_collisions"] + max_consecutive_collisions = self._params["max_consecutive_collisions"] + delta_collision_ms = self._params["delta_collision_ms"] + delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) + return_scaled = we._params["return_scaled"] unit_ids = we.unit_ids @@ -67,6 +90,8 @@ def _run(self, **job_kwargs): assert recording.get_num_channels() <= self._params["max_dense_channels"], "" sparsity = ChannelSparsity.create_dense(we) sparsity_inds = sparsity.unit_id_to_channel_indices + + # easier to use in chunk function as spikes use unit_index instead o id unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)} all_templates = we.get_all_templates() @@ -93,6 +118,9 @@ def _run(self, **job_kwargs): cut_out_before, cut_out_after, return_scaled, + handle_collisions, + max_consecutive_collisions, + delta_collision_samples, ) processor = ChunkRecordingExecutor( recording, @@ -154,6 +182,9 @@ def compute_amplitude_scalings( max_dense_channels=16, ms_before=None, ms_after=None, + handle_collisions=False, + max_consecutive_collisions=3, + delta_collision_ms=2, load_if_exists=False, outputs="concatenated", **job_kwargs, @@ -165,22 +196,29 @@ def compute_amplitude_scalings( ---------- waveform_extractor: WaveformExtractor The waveform extractor object - sparsity: ChannelSparsity + sparsity: ChannelSparsity, default: None If waveforms are not sparse, sparsity is required if the number of channels is greater than `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. - By default None max_dense_channels: int, default: 16 Maximum number of channels to allow running without sparsity. To compute amplitude scaling using dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. - ms_before : float, optional + ms_before : float, default: None The cut out to apply before the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_before is used, by default None - ms_after : float, optional + If None, the WaveformExtractor ms_before is used. + ms_after : float, default: None The cut out to apply after the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_after is used, by default None + If None, the WaveformExtractor ms_after is used. + handle_collisions: bool, default: False + Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes + (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a + multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. + max_consecutive_collisions: int, default: 3 + The maximum number of consecutive collisions to handle on each side of a spike. + delta_collision_ms: float, default: 2 + The maximum time difference in ms between two spikes to be considered as colliding. load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. - outputs: str + outputs: str, default: 'concatenated' How the output should be returned: - 'concatenated' - 'by_unit' @@ -197,7 +235,15 @@ def compute_amplitude_scalings( sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name) else: sac = AmplitudeScalingsCalculator(waveform_extractor) - sac.set_params(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) + sac.set_params( + sparsity=sparsity, + max_dense_channels=max_dense_channels, + ms_before=ms_before, + ms_after=ms_after, + handle_collisions=handle_collisions, + max_consecutive_collisions=max_consecutive_collisions, + delta_collision_ms=delta_collision_ms, + ) sac.run(**job_kwargs) amps = sac.get_data(outputs=outputs) @@ -218,6 +264,9 @@ def _init_worker_amplitude_scalings( cut_out_before, cut_out_after, return_scaled, + handle_collisions, + max_consecutive_collisions, + delta_collision_samples, ): # create a local dict per worker worker_ctx = {} @@ -229,9 +278,18 @@ def _init_worker_amplitude_scalings( worker_ctx["nafter"] = nafter worker_ctx["cut_out_before"] = cut_out_before worker_ctx["cut_out_after"] = cut_out_after - worker_ctx["margin"] = max(nbefore, nafter) worker_ctx["return_scaled"] = return_scaled worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices + worker_ctx["handle_collisions"] = handle_collisions + worker_ctx["max_consecutive_collisions"] = max_consecutive_collisions + worker_ctx["delta_collision_samples"] = delta_collision_samples + + if not handle_collisions: + worker_ctx["margin"] = max(nbefore, nafter) + else: + margin_waveforms = max(nbefore, nafter) + max_margin_collisions = int(max_consecutive_collisions * delta_collision_samples) + worker_ctx["margin"] = max(margin_waveforms, max_margin_collisions) return worker_ctx @@ -250,6 +308,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_after = worker_ctx["cut_out_after"] margin = worker_ctx["margin"] return_scaled = worker_ctx["return_scaled"] + handle_collisions = worker_ctx["handle_collisions"] + max_consecutive_collisions = worker_ctx["max_consecutive_collisions"] + delta_collision_samples = worker_ctx["delta_collision_samples"] spikes_in_segment = spikes[segment_slices[segment_index]] @@ -272,8 +333,24 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) offsets = recording.get_property("offset_to_uV") traces_with_margin = traces_with_margin.astype("float32") * gains + offsets + # set colliding spikes apart (if needed) + if handle_collisions: + overlapping_mask = _find_overlapping_mask( + local_spikes, max_consecutive_collisions, delta_collision_samples, unit_inds_to_channel_indices + ) + overlapping_spike_indices = overlapping_mask[:, max_consecutive_collisions] + print( + f"Found {len(overlapping_spike_indices)} overlapping spikes in segment {segment_index}! - chunk {start_frame} - {end_frame}" + ) + else: + overlapping_spike_indices = np.array([], dtype=int) + # get all waveforms - for spike in local_spikes: + scalings = np.zeros(len(local_spikes), dtype=float) + for spike_index, spike in enumerate(local_spikes): + if spike_index in overlapping_spike_indices: + # we deal with overlapping spikes later + continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] sparse_indices = unit_inds_to_channel_indices[unit_index] @@ -294,7 +371,286 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) local_waveforms.append(local_waveform) templates.append(template) linregress_res = linregress(template.flatten(), local_waveform.flatten()) - scalings.append(linregress_res[0]) - scalings = np.array(scalings) + scalings[spike_index] = linregress_res[0] + + # deal with collisions + if len(overlapping_spike_indices) > 0: + for overlapping in overlapping_mask: + spike_index = overlapping[max_consecutive_collisions] + overlapping_spikes = local_spikes[overlapping[overlapping >= 0]] + scaled_amps = _fit_collision( + overlapping_spikes, + traces_with_margin, + start_frame, + end_frame, + left, + right, + nbefore, + all_templates, + unit_inds_to_channel_indices, + cut_out_before, + cut_out_after, + ) + # get the right amplitude scaling + scalings[spike_index] = scaled_amps[np.where(overlapping >= 0)[0] == max_consecutive_collisions] return (scalings,) + + +### Collision handling ### +def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): + """ + Returns True if the unit indices i and j are overlapping, False otherwise + + Parameters + ---------- + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices + i: int + The first unit index + j: int + The second unit index + + Returns + ------- + bool + True if the unit indices i and j are overlapping, False otherwise + """ + if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0: + return True + else: + return False + + +def _find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): + """ + Finds the overlapping spikes for each spike in spikes and returns a boolean mask of shape + (n_spikes, 2 * max_consecutive_spikes + 1). + + Parameters + ---------- + spikes: np.array + An array of spikes + max_consecutive_spikes: int + The maximum number of consecutive spikes to consider + delta_overlap_samples: int + The maximum number of samples between two spikes to consider them as overlapping + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices + + Returns + ------- + overlapping_mask: np.array + A boolean mask of shape (n_spikes, 2 * max_consecutive_spikes + 1) where the central column (max_consecutive_spikes) + is the current spike index, while the other columns are the indices of the overlapping spikes. The first + max_consecutive_spikes columns are the pre-overlapping spikes, while the last max_consecutive_spikes columns are + the post-overlapping spikes. + """ + + # overlapping_mask is a matrix of shape (n_spikes, 2 * max_consecutive_spikes + 1) + # the central column (max_consecutive_spikes) is the current spike index, while the other columns are the + # indices of the overlapping spikes. The first max_consecutive_spikes columns are the pre-overlapping spikes, + # while the last max_consecutive_spikes columns are the post-overlapping spikes + # Rows with all -1 are non-colliding spikes and are removed later + overlapping_mask_full = -1 * np.ones((len(spikes), 2 * max_consecutive_spikes + 1), dtype=int) + overlapping_mask_full[:, max_consecutive_spikes] = np.arange(len(spikes)) + + for i, spike in enumerate(spikes): + # find the possible spikes per and post within max_consecutive_spikes * delta_overlap_samples + consecutive_window_pre = np.searchsorted( + spikes["sample_index"], + spike["sample_index"] - max_consecutive_spikes * delta_overlap_samples, + ) + consecutive_window_post = np.searchsorted( + spikes["sample_index"], + spike["sample_index"] + max_consecutive_spikes * delta_overlap_samples, + ) + pre_possible_consecutive_spikes = spikes[consecutive_window_pre:i][::-1] + post_possible_consecutive_spikes = spikes[i + 1 : consecutive_window_post] + + # here we fill in the overlapping information by consecutively looping through the possible consecutive spikes + # and checking the spatial overlap and the delay with the previous overlapping spike + # pre and post are hanlded separately. Note that the pre-spikes are already sorted backwards + + # overlap_rank keeps track of the rank of consecutive collisions (i.e., rank 0 is the first, rank 1 is the second, etc.) + # this is needed because we are just considering spikes with spatial overlap, while the possible consecutive spikes + # only looked at the temporal overlap + overlap_rank = 0 + if len(pre_possible_consecutive_spikes) > 0: + for c_pre, spike_consecutive_pre in enumerate(pre_possible_consecutive_spikes[::-1]): + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_pre["unit_index"] + ): + if ( + spikes[overlapping_mask_full[i, max_consecutive_spikes - overlap_rank]]["sample_index"] + - spike_consecutive_pre["sample_index"] + < delta_overlap_samples + ): + overlapping_mask_full[i, max_consecutive_spikes - overlap_rank - 1] = i - 1 - c_pre + overlap_rank += 1 + else: + break + # if overlap_rank > 1: + # print(f"\tHigher order pre-overlap for spike {i}!") + + overlap_rank = 0 + if len(post_possible_consecutive_spikes) > 0: + for c_post, spike_consecutive_post in enumerate(post_possible_consecutive_spikes): + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_post["unit_index"] + ): + if ( + spike_consecutive_post["sample_index"] + - spikes[overlapping_mask_full[i, max_consecutive_spikes + overlap_rank]]["sample_index"] + < delta_overlap_samples + ): + overlapping_mask_full[i, max_consecutive_spikes + overlap_rank + 1] = i + 1 + c_post + overlap_rank += 1 + else: + break + # if overlap_rank > 1: + # print(f"\tHigher order post-overlap for spike {i}!") + + # in case no collisions were found, we set the central column to -1 so that we can easily identify the non-colliding spikes + if np.sum(overlapping_mask_full[i] != -1) == 1: + overlapping_mask_full[i, max_consecutive_spikes] = -1 + + # only return rows with collisions + overlapping_inds = [] + for i, overlapping in enumerate(overlapping_mask_full): + if np.any(overlapping >= 0): + overlapping_inds.append(i) + overlapping_mask = overlapping_mask_full[overlapping_inds] + + return overlapping_mask + + +def _fit_collision( + overlapping_spikes, + traces_with_margin, + start_frame, + end_frame, + left, + right, + nbefore, + all_templates, + unit_inds_to_channel_indices, + cut_out_before, + cut_out_after, +): + """ """ + from sklearn.linear_model import LinearRegression + + sample_first_centered = overlapping_spikes[0]["sample_index"] - start_frame - left + sample_last_centered = overlapping_spikes[-1]["sample_index"] - start_frame - left + + # construct sparsity as union between units' sparsity + sparse_indices = np.array([], dtype="int") + for spike in overlapping_spikes: + sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] + sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + + local_waveform_start = max(0, sample_first_centered - cut_out_before) + local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) + local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] + + y = local_waveform.T.flatten() + X = np.zeros((len(y), len(overlapping_spikes))) + for i, spike in enumerate(overlapping_spikes): + full_template = np.zeros_like(local_waveform) + # center wrt cutout traces + sample_centered = spike["sample_index"] - local_waveform_start + template = all_templates[spike["unit_index"]][:, sparse_indices] + template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] + # deal with borders + if sample_centered - cut_out_before < 0: + full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :] + elif sample_centered + cut_out_after > end_frame + right: + full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)] + else: + full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut + X[:, i] = full_template.T.flatten() + + reg = LinearRegression().fit(X, y) + amps = reg.coef_ + return amps + + +# TODO: fix this! +# def plot_overlapping_spikes(we, overlap, +# spikes, cut_out_samples=100, +# max_consecutive_spikes=3, +# sparsity=None, +# fitted_amps=None): +# recording = we.recording +# nbefore_nafter_max = max(we.nafter, we.nbefore) +# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) +# spike_index = overlap[max_consecutive_spikes] +# overlap_indices = overlap[overlap != -1] +# overlapping_spikes = spikes[overlap_indices] + +# if sparsity is not None: +# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices +# sparse_indices = np.array([], dtype="int") +# for spike in overlapping_spikes: +# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] +# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) +# else: +# sparse_indices = np.unique(overlapping_spikes["channel_index"]) + +# channel_ids = recording.channel_ids[sparse_indices] + +# center_spike = spikes[spike_index]["sample_index"] +# max_delta = np.max([np.abs(center_spike - overlapping_spikes[0]["sample_index"]), +# np.abs(center_spike - overlapping_spikes[-1]["sample_index"])]) +# sf = center_spike - max_delta - cut_out_samples +# ef = center_spike + max_delta + cut_out_samples +# tr_overlap = recording.get_traces(start_frame=sf, +# end_frame=ef, +# channel_ids=channel_ids, return_scaled=True) +# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 +# max_tr = np.max(np.abs(tr_overlap)) +# fig, ax = plt.subplots() +# for ch, tr in enumerate(tr_overlap.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") +# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + +# used_labels = [] +# for spike in overlapping_spikes: +# label = f"U{spike['unit_index']}" +# if label in used_labels: +# label = None +# else: +# used_labels.append(label) +# ax.axvline(spike["sample_index"] / recording.sampling_frequency * 1000, +# color=f"C{spike['unit_index']}", label=label) + +# if fitted_amps is not None: +# fitted_traces = np.zeros_like(tr_overlap) + +# all_templates = we.get_all_templates() +# for i, spike in enumerate(overlapping_spikes): +# template = all_templates[spike["unit_index"]] +# template_scaled = fitted_amps[overlap_indices[i]] * template +# template_scaled_sparse = template_scaled[:, sparse_indices] +# sample_start = spike["sample_index"] - we.nbefore +# sample_end = sample_start + template_scaled_sparse.shape[0] + +# fitted_traces[sample_start - sf: sample_end - sf] += template_scaled_sparse + +# for ch, temp in enumerate(template_scaled_sparse.T): + +# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 +# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", +# ls="--") + +# for ch, tr in enumerate(fitted_traces.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) + +# fitted_line = ax.get_lines()[-1] +# fitted_line.set_label("Fitted") + + +# ax.legend() +# ax.set_title(f"Spike {spike_index} - sample {center_spike}") +# return tr_overlap, ax From b9391a69c26f027e40da7cf0c3b7cffbf68b2d5e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 28 Aug 2023 09:55:29 +0200 Subject: [PATCH 2/5] wip collisions --- .../postprocessing/amplitude_scalings.py | 301 ++++++++++++------ 1 file changed, 203 insertions(+), 98 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 7539e4d0b7..d367ef4f22 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -21,6 +21,7 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) + self.overlapping_mask = None def _set_params( self, @@ -132,8 +133,30 @@ def _run(self, **job_kwargs): **job_kwargs, ) out = processor.run() - (amp_scalings,) = zip(*out) + (amp_scalings, overlapping_mask) = zip(*out) amp_scalings = np.concatenate(amp_scalings) + if handle_collisions > 0: + from ..core.job_tools import divide_recording_into_chunks + + overlapping_mask_corrected = [] + all_chunks = divide_recording_into_chunks(processor.recording, processor.chunk_size) + num_spikes_so_far = 0 + for i, overlapping in enumerate(overlapping_mask): + if i == 0: + continue + segment_index = all_chunks[i - 1][0] + spikes_in_segment = self.spikes[segment_slices[segment_index]] + i0 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][1]) + i1 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][2]) + num_spikes_so_far += i1 - i0 + overlapping_corrected = overlapping.copy() + overlapping_corrected[overlapping_corrected >= 0] += num_spikes_so_far + overlapping_mask_corrected.append(overlapping_corrected) + overlapping_mask = np.concatenate(overlapping_mask_corrected) + print(f"Found {len(overlapping_mask)} overlapping spikes") + self.overlapping_mask = overlapping_mask + else: + overlapping_mask = np.concatenate(overlapping_mask) self._extension_data[f"amplitude_scalings"] = amp_scalings @@ -314,13 +337,10 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) spikes_in_segment = spikes[segment_slices[segment_index]] + # TODO: handle spikes in margin! i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) - local_waveforms = [] - templates = [] - scalings = [] - if i0 != i1: local_spikes = spikes_in_segment[i0:i1] traces_with_margin, left, right = get_chunk_with_margin( @@ -335,13 +355,10 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: - overlapping_mask = _find_overlapping_mask( + overlapping_mask = find_overlapping_mask( local_spikes, max_consecutive_collisions, delta_collision_samples, unit_inds_to_channel_indices ) overlapping_spike_indices = overlapping_mask[:, max_consecutive_collisions] - print( - f"Found {len(overlapping_spike_indices)} overlapping spikes in segment {segment_index}! - chunk {start_frame} - {end_frame}" - ) else: overlapping_spike_indices = np.array([], dtype=int) @@ -368,17 +385,18 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape - local_waveforms.append(local_waveform) - templates.append(template) + linregress_res = linregress(template.flatten(), local_waveform.flatten()) scalings[spike_index] = linregress_res[0] # deal with collisions if len(overlapping_spike_indices) > 0: for overlapping in overlapping_mask: + # the current spike is the one at the 'max_consecutive_collisions' position spike_index = overlapping[max_consecutive_collisions] overlapping_spikes = local_spikes[overlapping[overlapping >= 0]] - scaled_amps = _fit_collision( + current_spike_index_within_overlapping = np.where(overlapping >= 0)[0] == max_consecutive_collisions + scaled_amps = fit_collision( overlapping_spikes, traces_with_margin, start_frame, @@ -392,9 +410,12 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_after, ) # get the right amplitude scaling - scalings[spike_index] = scaled_amps[np.where(overlapping >= 0)[0] == max_consecutive_collisions] + scalings[spike_index] = scaled_amps[current_spike_index_within_overlapping] + else: + scalings = np.array([]) + overlapping_mask = np.array([], shape=(0, max_consecutive_collisions + 1)) - return (scalings,) + return (scalings, overlapping_mask) ### Collision handling ### @@ -422,7 +443,7 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): return False -def _find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): +def find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): """ Finds the overlapping spikes for each spike in spikes and returns a boolean mask of shape (n_spikes, 2 * max_consecutive_spikes + 1). @@ -525,7 +546,7 @@ def _find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples return overlapping_mask -def _fit_collision( +def fit_collision( overlapping_spikes, traces_with_margin, start_frame, @@ -537,8 +558,41 @@ def _fit_collision( unit_inds_to_channel_indices, cut_out_before, cut_out_after, + debug=True, ): - """ """ + """ + Compute the best fit for a collision between a spike and its overlapping spikes. + + Parameters + ---------- + overlapping_spikes: np.ndarray + A numpy array of shape (n_overlapping_spikes, ) containing the overlapping spikes (spike_dtype). + traces_with_margin: np.ndarray + A numpy array of shape (n_samples, n_channels) containing the traces with a margin. + start_frame: int + The start frame of the chunk for traces_with_margin. + end_frame: int + The end frame of the chunk for traces_with_margin. + left: int + The left margin of the chunk for traces_with_margin. + right: int + The right margin of the chunk for traces_with_margin. + nbefore: int + The number of samples before the spike to consider for the fit. + all_templates: np.ndarray + A numpy array of shape (n_units, n_samples, n_channels) containing the templates. + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices. + cut_out_before: int + The number of samples to cut out before the spike. + cut_out_after: int + The number of samples to cut out after the spike. + + Returns + ------- + np.ndarray + The fitted scaling factors for the overlapping spikes. + """ from sklearn.linear_model import LinearRegression sample_first_centered = overlapping_spikes[0]["sample_index"] - start_frame - left @@ -550,6 +604,7 @@ def _fit_collision( sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + # TODO: check alignment!!! local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] @@ -559,7 +614,7 @@ def _fit_collision( for i, spike in enumerate(overlapping_spikes): full_template = np.zeros_like(local_waveform) # center wrt cutout traces - sample_centered = spike["sample_index"] - local_waveform_start + sample_centered = spike["sample_index"] - start_frame - left - local_waveform_start template = all_templates[spike["unit_index"]][:, sparse_indices] template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] # deal with borders @@ -571,86 +626,136 @@ def _fit_collision( full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten() + if debug: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + max_tr = np.max(np.abs(local_waveform)) + + _ = ax.plot(y, color="k") + + for i, spike in enumerate(overlapping_spikes): + _ = ax.plot(X[:, i], color=f"C{i}", alpha=0.5) + plt.show() + reg = LinearRegression().fit(X, y) - amps = reg.coef_ - return amps + scalings = reg.coef_ + return scalings + + +def plot_collisions(we, sparsity=None, num_collisions=None): + """ + Plot the fitting of collision spikes. + + Parameters + ---------- + we : WaveformExtractor + The WaveformExtractor object. + sparsity : ChannelSparsity, default=None + The ChannelSparsity. If None, only main channels are plotted. + num_collisions : int, default=None + Number of collisions to plot. If None, all collisions are plotted. + """ + assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" + sac = we.load_extension("amplitude_scalings") + handle_collisions = sac._params["handle_collisions"] + assert handle_collisions, "Amplitude scalings was run without handling collisions!" + scalings = sac.get_data() + + overlapping_mask = sac.overlapping_mask + num_collisions = num_collisions or len(overlapping_mask) + spikes = sac.spikes + max_consecutive_collisions = sac._params["max_consecutive_collisions"] + + for i in range(num_collisions): + ax = _plot_one_collision( + we, overlapping_mask[i], spikes, scalings=scalings, max_consecutive_collisions=max_consecutive_collisions + ) + + +def _plot_one_collision( + we, + overlap, + spikes, + scalings=None, + sparsity=None, + cut_out_samples=100, + max_consecutive_collisions=3, +): + import matplotlib.pyplot as plt + + recording = we.recording + nbefore_nafter_max = max(we.nafter, we.nbefore) + cut_out_samples = max(cut_out_samples, nbefore_nafter_max) + spike_index = overlap[max_consecutive_collisions] + overlap_indices = overlap[overlap != -1] + overlapping_spikes = spikes[overlap_indices] + + if sparsity is not None: + unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices + sparse_indices = np.array([], dtype="int") + for spike in overlapping_spikes: + sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] + sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + else: + sparse_indices = np.unique(overlapping_spikes["channel_index"]) + + channel_ids = recording.channel_ids[sparse_indices] + + center_spike = spikes[spike_index] + max_delta = np.max( + [ + np.abs(center_spike["sample_index"] - overlapping_spikes[0]["sample_index"]), + np.abs(center_spike["sample_index"] - overlapping_spikes[-1]["sample_index"]), + ] + ) + sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) + ef = min( + center_spike["sample_index"] + max_delta + cut_out_samples, + recording.get_num_samples(segment_index=center_spike["segment_index"]), + ) + tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) + ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 + max_tr = np.max(np.abs(tr_overlap)) + fig, ax = plt.subplots() + for ch, tr in enumerate(tr_overlap.T): + _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") + ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + + used_labels = [] + for spike in overlapping_spikes: + label = f"U{spike['unit_index']}" + if label in used_labels: + label = None + else: + used_labels.append(label) + ax.axvline( + spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label + ) + + if scalings is not None: + fitted_traces = np.zeros_like(tr_overlap) + + all_templates = we.get_all_templates() + for i, spike in enumerate(overlapping_spikes): + template = all_templates[spike["unit_index"]] + template_scaled = scalings[overlap_indices[i]] * template + template_scaled_sparse = template_scaled[:, sparse_indices] + sample_start = spike["sample_index"] - we.nbefore + sample_end = sample_start + template_scaled_sparse.shape[0] + + fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse + + for ch, temp in enumerate(template_scaled_sparse.T): + ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 + _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") + + for ch, tr in enumerate(fitted_traces.T): + _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) + fitted_line = ax.get_lines()[-1] + fitted_line.set_label("Fitted") -# TODO: fix this! -# def plot_overlapping_spikes(we, overlap, -# spikes, cut_out_samples=100, -# max_consecutive_spikes=3, -# sparsity=None, -# fitted_amps=None): -# recording = we.recording -# nbefore_nafter_max = max(we.nafter, we.nbefore) -# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) -# spike_index = overlap[max_consecutive_spikes] -# overlap_indices = overlap[overlap != -1] -# overlapping_spikes = spikes[overlap_indices] - -# if sparsity is not None: -# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices -# sparse_indices = np.array([], dtype="int") -# for spike in overlapping_spikes: -# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] -# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) -# else: -# sparse_indices = np.unique(overlapping_spikes["channel_index"]) - -# channel_ids = recording.channel_ids[sparse_indices] - -# center_spike = spikes[spike_index]["sample_index"] -# max_delta = np.max([np.abs(center_spike - overlapping_spikes[0]["sample_index"]), -# np.abs(center_spike - overlapping_spikes[-1]["sample_index"])]) -# sf = center_spike - max_delta - cut_out_samples -# ef = center_spike + max_delta + cut_out_samples -# tr_overlap = recording.get_traces(start_frame=sf, -# end_frame=ef, -# channel_ids=channel_ids, return_scaled=True) -# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 -# max_tr = np.max(np.abs(tr_overlap)) -# fig, ax = plt.subplots() -# for ch, tr in enumerate(tr_overlap.T): -# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") -# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") - -# used_labels = [] -# for spike in overlapping_spikes: -# label = f"U{spike['unit_index']}" -# if label in used_labels: -# label = None -# else: -# used_labels.append(label) -# ax.axvline(spike["sample_index"] / recording.sampling_frequency * 1000, -# color=f"C{spike['unit_index']}", label=label) - -# if fitted_amps is not None: -# fitted_traces = np.zeros_like(tr_overlap) - -# all_templates = we.get_all_templates() -# for i, spike in enumerate(overlapping_spikes): -# template = all_templates[spike["unit_index"]] -# template_scaled = fitted_amps[overlap_indices[i]] * template -# template_scaled_sparse = template_scaled[:, sparse_indices] -# sample_start = spike["sample_index"] - we.nbefore -# sample_end = sample_start + template_scaled_sparse.shape[0] - -# fitted_traces[sample_start - sf: sample_end - sf] += template_scaled_sparse - -# for ch, temp in enumerate(template_scaled_sparse.T): - -# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 -# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", -# ls="--") - -# for ch, tr in enumerate(fitted_traces.T): -# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) - -# fitted_line = ax.get_lines()[-1] -# fitted_line.set_label("Fitted") - - -# ax.legend() -# ax.set_title(f"Spike {spike_index} - sample {center_spike}") -# return tr_overlap, ax + ax.legend() + ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") + return ax From 7bec9df5c0298c853f552989ef5e3febcf0f9470 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 28 Aug 2023 17:52:07 +0200 Subject: [PATCH 3/5] Simplify and cleanup --- .../postprocessing/amplitude_scalings.py | 484 ++++++++---------- 1 file changed, 206 insertions(+), 278 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index d367ef4f22..1f7923eb05 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -7,6 +7,9 @@ from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension +# DEBUG = True + + class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension): """ Computes amplitude scalings from WaveformExtractor. @@ -21,7 +24,6 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) - self.overlapping_mask = None def _set_params( self, @@ -30,7 +32,6 @@ def _set_params( ms_before, ms_after, handle_collisions, - max_consecutive_collisions, delta_collision_ms, ): params = dict( @@ -39,7 +40,6 @@ def _set_params( ms_before=ms_before, ms_after=ms_after, handle_collisions=handle_collisions, - max_consecutive_collisions=max_consecutive_collisions, delta_collision_ms=delta_collision_ms, ) return params @@ -63,7 +63,6 @@ def _run(self, **job_kwargs): # collisions handle_collisions = self._params["handle_collisions"] - max_consecutive_collisions = self._params["max_consecutive_collisions"] delta_collision_ms = self._params["delta_collision_ms"] delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) @@ -120,7 +119,6 @@ def _run(self, **job_kwargs): cut_out_after, return_scaled, handle_collisions, - max_consecutive_collisions, delta_collision_samples, ) processor = ChunkRecordingExecutor( @@ -133,32 +131,16 @@ def _run(self, **job_kwargs): **job_kwargs, ) out = processor.run() - (amp_scalings, overlapping_mask) = zip(*out) + (amp_scalings, collisions) = zip(*out) amp_scalings = np.concatenate(amp_scalings) - if handle_collisions > 0: - from ..core.job_tools import divide_recording_into_chunks - - overlapping_mask_corrected = [] - all_chunks = divide_recording_into_chunks(processor.recording, processor.chunk_size) - num_spikes_so_far = 0 - for i, overlapping in enumerate(overlapping_mask): - if i == 0: - continue - segment_index = all_chunks[i - 1][0] - spikes_in_segment = self.spikes[segment_slices[segment_index]] - i0 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][1]) - i1 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][2]) - num_spikes_so_far += i1 - i0 - overlapping_corrected = overlapping.copy() - overlapping_corrected[overlapping_corrected >= 0] += num_spikes_so_far - overlapping_mask_corrected.append(overlapping_corrected) - overlapping_mask = np.concatenate(overlapping_mask_corrected) - print(f"Found {len(overlapping_mask)} overlapping spikes") - self.overlapping_mask = overlapping_mask - else: - overlapping_mask = np.concatenate(overlapping_mask) + + collisions_dict = {} + if handle_collisions: + for collision in collisions: + collisions_dict.update(collision) self._extension_data[f"amplitude_scalings"] = amp_scalings + self._extension_data[f"collisions"] = collisions_dict def get_data(self, outputs="concatenated"): """ @@ -206,7 +188,6 @@ def compute_amplitude_scalings( ms_before=None, ms_after=None, handle_collisions=False, - max_consecutive_collisions=3, delta_collision_ms=2, load_if_exists=False, outputs="concatenated", @@ -235,10 +216,8 @@ def compute_amplitude_scalings( Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. - max_consecutive_collisions: int, default: 3 - The maximum number of consecutive collisions to handle on each side of a spike. delta_collision_ms: float, default: 2 - The maximum time difference in ms between two spikes to be considered as colliding. + The maximum time difference in ms before and after a spike to gather colliding spikes. load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. outputs: str, default: 'concatenated' @@ -264,7 +243,6 @@ def compute_amplitude_scalings( ms_before=ms_before, ms_after=ms_after, handle_collisions=handle_collisions, - max_consecutive_collisions=max_consecutive_collisions, delta_collision_ms=delta_collision_ms, ) sac.run(**job_kwargs) @@ -288,7 +266,6 @@ def _init_worker_amplitude_scalings( cut_out_after, return_scaled, handle_collisions, - max_consecutive_collisions, delta_collision_samples, ): # create a local dict per worker @@ -304,15 +281,15 @@ def _init_worker_amplitude_scalings( worker_ctx["return_scaled"] = return_scaled worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices worker_ctx["handle_collisions"] = handle_collisions - worker_ctx["max_consecutive_collisions"] = max_consecutive_collisions worker_ctx["delta_collision_samples"] = delta_collision_samples if not handle_collisions: worker_ctx["margin"] = max(nbefore, nafter) else: + # in this case we extend the margin to be able to get with collisions outside the chunk margin_waveforms = max(nbefore, nafter) - max_margin_collisions = int(max_consecutive_collisions * delta_collision_samples) - worker_ctx["margin"] = max(margin_waveforms, max_margin_collisions) + max_margin_collisions = delta_collision_samples + margin_waveforms + worker_ctx["margin"] = max_margin_collisions return worker_ctx @@ -332,7 +309,6 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) margin = worker_ctx["margin"] return_scaled = worker_ctx["return_scaled"] handle_collisions = worker_ctx["handle_collisions"] - max_consecutive_collisions = worker_ctx["max_consecutive_collisions"] delta_collision_samples = worker_ctx["delta_collision_samples"] spikes_in_segment = spikes[segment_slices[segment_index]] @@ -355,17 +331,21 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: - overlapping_mask = find_overlapping_mask( - local_spikes, max_consecutive_collisions, delta_collision_samples, unit_inds_to_channel_indices + # local spikes with margin! + i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) + i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] + collisions = find_collisions( + local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices ) - overlapping_spike_indices = overlapping_mask[:, max_consecutive_collisions] else: - overlapping_spike_indices = np.array([], dtype=int) + collisions = {} - # get all waveforms + # compute the scaling for each spike scalings = np.zeros(len(local_spikes), dtype=float) + collisions_dict = {} for spike_index, spike in enumerate(local_spikes): - if spike_index in overlapping_spike_indices: + if spike_index in collisions.keys(): # we deal with overlapping spikes later continue unit_index = spike["unit_index"] @@ -390,14 +370,13 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) scalings[spike_index] = linregress_res[0] # deal with collisions - if len(overlapping_spike_indices) > 0: - for overlapping in overlapping_mask: - # the current spike is the one at the 'max_consecutive_collisions' position - spike_index = overlapping[max_consecutive_collisions] - overlapping_spikes = local_spikes[overlapping[overlapping >= 0]] - current_spike_index_within_overlapping = np.where(overlapping >= 0)[0] == max_consecutive_collisions + if len(collisions) > 0: + num_spikes_in_previous_segments = int( + np.sum([len(spikes[segment_slices[s]]) for s in range(segment_index)]) + ) + for spike_index, collision in collisions.items(): scaled_amps = fit_collision( - overlapping_spikes, + collision, traces_with_margin, start_frame, end_frame, @@ -409,13 +388,16 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_before, cut_out_after, ) - # get the right amplitude scaling - scalings[spike_index] = scaled_amps[current_spike_index_within_overlapping] + # the scaling for the current spike is at index 0 + scalings[spike_index] = scaled_amps[0] + + # make collision_dict indices "absolute" by adding i0 and the cumulative number of spikes in previous segments + collisions_dict.update({spike_index + i0 + num_spikes_in_previous_segments: collision}) else: scalings = np.array([]) - overlapping_mask = np.array([], shape=(0, max_consecutive_collisions + 1)) + collisions_dict = {} - return (scalings, overlapping_mask) + return (scalings, collisions_dict) ### Collision handling ### @@ -443,111 +425,65 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): return False -def find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): +def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices): """ - Finds the overlapping spikes for each spike in spikes and returns a boolean mask of shape - (n_spikes, 2 * max_consecutive_spikes + 1). + Finds the collisions between spikes. Parameters ---------- spikes: np.array An array of spikes - max_consecutive_spikes: int - The maximum number of consecutive spikes to consider - delta_overlap_samples: int + spikes_w_margin: np.array + An array of spikes within the added margin + delta_collision_samples: int The maximum number of samples between two spikes to consider them as overlapping unit_inds_to_channel_indices: dict A dictionary mapping unit indices to channel indices Returns ------- - overlapping_mask: np.array - A boolean mask of shape (n_spikes, 2 * max_consecutive_spikes + 1) where the central column (max_consecutive_spikes) - is the current spike index, while the other columns are the indices of the overlapping spikes. The first - max_consecutive_spikes columns are the pre-overlapping spikes, while the last max_consecutive_spikes columns are - the post-overlapping spikes. + collision_spikes_dict: np.array + A dictionary with collisions. The key is the index of the spike with collision, the value is an + array of overlapping spikes, including the spike itself at position 0. """ + collision_spikes_dict = {} + for spike_index, spike in enumerate(spikes): + # find the index of the spike within the spikes_w_margin + spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] - # overlapping_mask is a matrix of shape (n_spikes, 2 * max_consecutive_spikes + 1) - # the central column (max_consecutive_spikes) is the current spike index, while the other columns are the - # indices of the overlapping spikes. The first max_consecutive_spikes columns are the pre-overlapping spikes, - # while the last max_consecutive_spikes columns are the post-overlapping spikes - # Rows with all -1 are non-colliding spikes and are removed later - overlapping_mask_full = -1 * np.ones((len(spikes), 2 * max_consecutive_spikes + 1), dtype=int) - overlapping_mask_full[:, max_consecutive_spikes] = np.arange(len(spikes)) - - for i, spike in enumerate(spikes): - # find the possible spikes per and post within max_consecutive_spikes * delta_overlap_samples + # find the possible spikes per and post within delta_collision_samples consecutive_window_pre = np.searchsorted( - spikes["sample_index"], - spike["sample_index"] - max_consecutive_spikes * delta_overlap_samples, + spikes_w_margin["sample_index"], + spike["sample_index"] - delta_collision_samples, ) consecutive_window_post = np.searchsorted( - spikes["sample_index"], - spike["sample_index"] + max_consecutive_spikes * delta_overlap_samples, + spikes_w_margin["sample_index"], + spike["sample_index"] + delta_collision_samples, + ) + # exclude the spike itself (it is included in the collision_spikes by construction) + pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) + post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) + possible_overlapping_spike_indices = np.concatenate( + (pre_possible_consecutive_spike_indices, post_possible_consecutive_spike_indices) ) - pre_possible_consecutive_spikes = spikes[consecutive_window_pre:i][::-1] - post_possible_consecutive_spikes = spikes[i + 1 : consecutive_window_post] - - # here we fill in the overlapping information by consecutively looping through the possible consecutive spikes - # and checking the spatial overlap and the delay with the previous overlapping spike - # pre and post are hanlded separately. Note that the pre-spikes are already sorted backwards - - # overlap_rank keeps track of the rank of consecutive collisions (i.e., rank 0 is the first, rank 1 is the second, etc.) - # this is needed because we are just considering spikes with spatial overlap, while the possible consecutive spikes - # only looked at the temporal overlap - overlap_rank = 0 - if len(pre_possible_consecutive_spikes) > 0: - for c_pre, spike_consecutive_pre in enumerate(pre_possible_consecutive_spikes[::-1]): - if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_pre["unit_index"] - ): - if ( - spikes[overlapping_mask_full[i, max_consecutive_spikes - overlap_rank]]["sample_index"] - - spike_consecutive_pre["sample_index"] - < delta_overlap_samples - ): - overlapping_mask_full[i, max_consecutive_spikes - overlap_rank - 1] = i - 1 - c_pre - overlap_rank += 1 - else: - break - # if overlap_rank > 1: - # print(f"\tHigher order pre-overlap for spike {i}!") - - overlap_rank = 0 - if len(post_possible_consecutive_spikes) > 0: - for c_post, spike_consecutive_post in enumerate(post_possible_consecutive_spikes): - if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_post["unit_index"] - ): - if ( - spike_consecutive_post["sample_index"] - - spikes[overlapping_mask_full[i, max_consecutive_spikes + overlap_rank]]["sample_index"] - < delta_overlap_samples - ): - overlapping_mask_full[i, max_consecutive_spikes + overlap_rank + 1] = i + 1 + c_post - overlap_rank += 1 - else: - break - # if overlap_rank > 1: - # print(f"\tHigher order post-overlap for spike {i}!") - - # in case no collisions were found, we set the central column to -1 so that we can easily identify the non-colliding spikes - if np.sum(overlapping_mask_full[i] != -1) == 1: - overlapping_mask_full[i, max_consecutive_spikes] = -1 - - # only return rows with collisions - overlapping_inds = [] - for i, overlapping in enumerate(overlapping_mask_full): - if np.any(overlapping >= 0): - overlapping_inds.append(i) - overlapping_mask = overlapping_mask_full[overlapping_inds] - - return overlapping_mask + + # find the overlapping spikes in space as well + for possible_overlapping_spike_index in possible_overlapping_spike_indices: + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, + spike["unit_index"], + spikes_w_margin[possible_overlapping_spike_index]["unit_index"], + ): + if spike_index not in collision_spikes_dict: + collision_spikes_dict[spike_index] = np.array([spike]) + collision_spikes_dict[spike_index] = np.concatenate( + (collision_spikes_dict[spike_index], [spikes_w_margin[possible_overlapping_spike_index]]) + ) + return collision_spikes_dict def fit_collision( - overlapping_spikes, + collision, traces_with_margin, start_frame, end_frame, @@ -558,15 +494,16 @@ def fit_collision( unit_inds_to_channel_indices, cut_out_before, cut_out_after, - debug=True, ): """ Compute the best fit for a collision between a spike and its overlapping spikes. + The function first cuts out the traces around the spike and its overlapping spikes, then + fits a multi-linear regression model to the traces using the centered templates as predictors. Parameters ---------- - overlapping_spikes: np.ndarray - A numpy array of shape (n_overlapping_spikes, ) containing the overlapping spikes (spike_dtype). + collision: np.ndarray + A numpy array of shape (n_colliding_spikes, ) containing the colliding spikes (spike_dtype). traces_with_margin: np.ndarray A numpy array of shape (n_samples, n_channels) containing the traces with a margin. start_frame: int @@ -591,30 +528,30 @@ def fit_collision( Returns ------- np.ndarray - The fitted scaling factors for the overlapping spikes. + The fitted scaling factors for the colliding spikes. """ from sklearn.linear_model import LinearRegression - sample_first_centered = overlapping_spikes[0]["sample_index"] - start_frame - left - sample_last_centered = overlapping_spikes[-1]["sample_index"] - start_frame - left + # make center of the spike externally + sample_first_centered = np.min(collision["sample_index"]) - (start_frame - left) + sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity sparse_indices = np.array([], dtype="int") - for spike in overlapping_spikes: + for spike in collision: sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] sparse_indices = np.union1d(sparse_indices, sparse_indices_i) - # TODO: check alignment!!! local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] y = local_waveform.T.flatten() - X = np.zeros((len(y), len(overlapping_spikes))) - for i, spike in enumerate(overlapping_spikes): + X = np.zeros((len(y), len(collision))) + for i, spike in enumerate(collision): full_template = np.zeros_like(local_waveform) # center wrt cutout traces - sample_centered = spike["sample_index"] - start_frame - left - local_waveform_start + sample_centered = spike["sample_index"] - (start_frame - left) - local_waveform_start template = all_templates[spike["unit_index"]][:, sparse_indices] template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] # deal with borders @@ -626,136 +563,127 @@ def fit_collision( full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten() - if debug: - import matplotlib.pyplot as plt - - fig, ax = plt.subplots() - max_tr = np.max(np.abs(local_waveform)) - - _ = ax.plot(y, color="k") - - for i, spike in enumerate(overlapping_spikes): - _ = ax.plot(X[:, i], color=f"C{i}", alpha=0.5) - plt.show() - reg = LinearRegression().fit(X, y) scalings = reg.coef_ return scalings -def plot_collisions(we, sparsity=None, num_collisions=None): - """ - Plot the fitting of collision spikes. - - Parameters - ---------- - we : WaveformExtractor - The WaveformExtractor object. - sparsity : ChannelSparsity, default=None - The ChannelSparsity. If None, only main channels are plotted. - num_collisions : int, default=None - Number of collisions to plot. If None, all collisions are plotted. - """ - assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" - sac = we.load_extension("amplitude_scalings") - handle_collisions = sac._params["handle_collisions"] - assert handle_collisions, "Amplitude scalings was run without handling collisions!" - scalings = sac.get_data() - - overlapping_mask = sac.overlapping_mask - num_collisions = num_collisions or len(overlapping_mask) - spikes = sac.spikes - max_consecutive_collisions = sac._params["max_consecutive_collisions"] - - for i in range(num_collisions): - ax = _plot_one_collision( - we, overlapping_mask[i], spikes, scalings=scalings, max_consecutive_collisions=max_consecutive_collisions - ) - - -def _plot_one_collision( - we, - overlap, - spikes, - scalings=None, - sparsity=None, - cut_out_samples=100, - max_consecutive_collisions=3, -): - import matplotlib.pyplot as plt - - recording = we.recording - nbefore_nafter_max = max(we.nafter, we.nbefore) - cut_out_samples = max(cut_out_samples, nbefore_nafter_max) - spike_index = overlap[max_consecutive_collisions] - overlap_indices = overlap[overlap != -1] - overlapping_spikes = spikes[overlap_indices] - - if sparsity is not None: - unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices - sparse_indices = np.array([], dtype="int") - for spike in overlapping_spikes: - sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] - sparse_indices = np.union1d(sparse_indices, sparse_indices_i) - else: - sparse_indices = np.unique(overlapping_spikes["channel_index"]) - - channel_ids = recording.channel_ids[sparse_indices] - - center_spike = spikes[spike_index] - max_delta = np.max( - [ - np.abs(center_spike["sample_index"] - overlapping_spikes[0]["sample_index"]), - np.abs(center_spike["sample_index"] - overlapping_spikes[-1]["sample_index"]), - ] - ) - sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) - ef = min( - center_spike["sample_index"] + max_delta + cut_out_samples, - recording.get_num_samples(segment_index=center_spike["segment_index"]), - ) - tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) - ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 - max_tr = np.max(np.abs(tr_overlap)) - fig, ax = plt.subplots() - for ch, tr in enumerate(tr_overlap.T): - _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") - ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") - - used_labels = [] - for spike in overlapping_spikes: - label = f"U{spike['unit_index']}" - if label in used_labels: - label = None - else: - used_labels.append(label) - ax.axvline( - spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label - ) - - if scalings is not None: - fitted_traces = np.zeros_like(tr_overlap) - - all_templates = we.get_all_templates() - for i, spike in enumerate(overlapping_spikes): - template = all_templates[spike["unit_index"]] - template_scaled = scalings[overlap_indices[i]] * template - template_scaled_sparse = template_scaled[:, sparse_indices] - sample_start = spike["sample_index"] - we.nbefore - sample_end = sample_start + template_scaled_sparse.shape[0] - - fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse - - for ch, temp in enumerate(template_scaled_sparse.T): - ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 - _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") - - for ch, tr in enumerate(fitted_traces.T): - _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) - - fitted_line = ax.get_lines()[-1] - fitted_line.set_label("Fitted") - - ax.legend() - ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") - return ax +# uncomment for debugging +# def plot_collisions(we, sparsity=None, num_collisions=None): +# """ +# Plot the fitting of collision spikes. + +# Parameters +# ---------- +# we : WaveformExtractor +# The WaveformExtractor object. +# sparsity : ChannelSparsity, default=None +# The ChannelSparsity. If None, only main channels are plotted. +# num_collisions : int, default=None +# Number of collisions to plot. If None, all collisions are plotted. +# """ +# assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" +# sac = we.load_extension("amplitude_scalings") +# handle_collisions = sac._params["handle_collisions"] +# assert handle_collisions, "Amplitude scalings was run without handling collisions!" +# scalings = sac.get_data() + +# # overlapping_mask = sac.overlapping_mask +# # num_collisions = num_collisions or len(overlapping_mask) +# spikes = sac.spikes +# collisions = sac._extension_data[f"collisions"] +# collision_keys = list(collisions.keys()) +# num_collisions = num_collisions or len(collisions) +# num_collisions = min(num_collisions, len(collisions)) + +# for i in range(num_collisions): +# overlapping_spikes = collisions[collision_keys[i]] +# ax = _plot_one_collision( +# we, collision_keys[i], overlapping_spikes, spikes, scalings=scalings, sparsity=sparsity +# ) + + +# def _plot_one_collision( +# we, +# spike_index, +# overlapping_spikes, +# spikes, +# scalings=None, +# sparsity=None, +# cut_out_samples=100, +# ): +# import matplotlib.pyplot as plt + +# recording = we.recording +# nbefore_nafter_max = max(we.nafter, we.nbefore) +# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) + +# if sparsity is not None: +# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices +# sparse_indices = np.array([], dtype="int") +# for spike in overlapping_spikes: +# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] +# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) +# else: +# sparse_indices = np.unique(overlapping_spikes["channel_index"]) + +# channel_ids = recording.channel_ids[sparse_indices] + +# center_spike = overlapping_spikes[0] +# max_delta = np.max( +# [ +# np.abs(center_spike["sample_index"] - np.min(overlapping_spikes[1:]["sample_index"])), +# np.abs(center_spike["sample_index"] - np.max(overlapping_spikes[1:]["sample_index"])), +# ] +# ) +# sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) +# ef = min( +# center_spike["sample_index"] + max_delta + cut_out_samples, +# recording.get_num_samples(segment_index=center_spike["segment_index"]), +# ) +# tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) +# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 +# max_tr = np.max(np.abs(tr_overlap)) +# fig, ax = plt.subplots() +# for ch, tr in enumerate(tr_overlap.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") +# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + +# used_labels = [] +# for i, spike in enumerate(overlapping_spikes): +# label = f"U{spike['unit_index']}" +# if label in used_labels: +# label = None +# else: +# used_labels.append(label) +# ax.axvline( +# spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label +# ) + +# if scalings is not None: +# fitted_traces = np.zeros_like(tr_overlap) + +# all_templates = we.get_all_templates() +# for i, spike in enumerate(overlapping_spikes): +# template = all_templates[spike["unit_index"]] +# overlap_index = np.where(spikes == spike)[0][0] +# template_scaled = scalings[overlap_index] * template +# template_scaled_sparse = template_scaled[:, sparse_indices] +# sample_start = spike["sample_index"] - we.nbefore +# sample_end = sample_start + template_scaled_sparse.shape[0] + +# fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse + +# for ch, temp in enumerate(template_scaled_sparse.T): +# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 +# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") + +# for ch, tr in enumerate(fitted_traces.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) + +# fitted_line = ax.get_lines()[-1] +# fitted_line.set_label("Fitted") + +# ax.legend() +# ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") +# return ax From a797aa33c561871b10b4a441985d89546e8ebc2e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 30 Aug 2023 15:55:16 +0200 Subject: [PATCH 4/5] Improve debug plots and handle_collisions=True by default --- .../postprocessing/amplitude_scalings.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 1f7923eb05..a9b3898388 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -187,7 +187,7 @@ def compute_amplitude_scalings( max_dense_channels=16, ms_before=None, ms_after=None, - handle_collisions=False, + handle_collisions=True, delta_collision_ms=2, load_if_exists=False, outputs="concatenated", @@ -212,7 +212,7 @@ def compute_amplitude_scalings( ms_after : float, default: None The cut out to apply after the spike peak to extract local waveforms. If None, the WaveformExtractor ms_after is used. - handle_collisions: bool, default: False + handle_collisions: bool, default: True Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. @@ -598,12 +598,12 @@ def fit_collision( # for i in range(num_collisions): # overlapping_spikes = collisions[collision_keys[i]] -# ax = _plot_one_collision( +# ax = plot_one_collision( # we, collision_keys[i], overlapping_spikes, spikes, scalings=scalings, sparsity=sparsity # ) -# def _plot_one_collision( +# def plot_one_collision( # we, # spike_index, # overlapping_spikes, @@ -611,9 +611,13 @@ def fit_collision( # scalings=None, # sparsity=None, # cut_out_samples=100, +# ax=None # ): # import matplotlib.pyplot as plt +# if ax is None: +# fig, ax = plt.subplots() + # recording = we.recording # nbefore_nafter_max = max(we.nafter, we.nbefore) # cut_out_samples = max(cut_out_samples, nbefore_nafter_max) @@ -644,7 +648,7 @@ def fit_collision( # tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) # ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 # max_tr = np.max(np.abs(tr_overlap)) -# fig, ax = plt.subplots() + # for ch, tr in enumerate(tr_overlap.T): # _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") # ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") From 287e8af9621385d4fa835be6356b7695993cdc16 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Sep 2023 16:46:18 +0200 Subject: [PATCH 5/5] Fix tests --- src/spikeinterface/postprocessing/amplitude_scalings.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 0dd2587fba..5a0148c5c4 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -24,6 +24,7 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) + self.collisions = None def _set_params( self, @@ -138,9 +139,11 @@ def _run(self, **job_kwargs): if handle_collisions: for collision in collisions: collisions_dict.update(collision) + self.collisions = collisions_dict + # Note: collisions are note in _extension_data because they are not pickable. We only store the indices + self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) - self._extension_data[f"amplitude_scalings"] = amp_scalings - self._extension_data[f"collisions"] = collisions_dict + self._extension_data["amplitude_scalings"] = amp_scalings def get_data(self, outputs="concatenated"): """