diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..4dab68fdf8 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -90,10 +90,7 @@ def _run(self, **job_kwargs): if self._params["max_dense_channels"] is not None: assert recording.get_num_channels() <= self._params["max_dense_channels"], "" sparsity = ChannelSparsity.create_dense(we) - sparsity_inds = sparsity.unit_id_to_channel_indices - - # easier to use in chunk function as spikes use unit_index instead o id - unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)} + sparsity_mask = sparsity.mask all_templates = we.get_all_templates() # precompute segment slice @@ -113,7 +110,7 @@ def _run(self, **job_kwargs): self.spikes, all_templates, segment_slices, - unit_inds_to_channel_indices, + sparsity_mask, nbefore, nafter, cut_out_before, @@ -262,7 +259,7 @@ def _init_worker_amplitude_scalings( spikes, all_templates, segment_slices, - unit_inds_to_channel_indices, + sparsity_mask, nbefore, nafter, cut_out_before, @@ -282,7 +279,7 @@ def _init_worker_amplitude_scalings( worker_ctx["cut_out_before"] = cut_out_before worker_ctx["cut_out_after"] = cut_out_after worker_ctx["return_scaled"] = return_scaled - worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices + worker_ctx["sparsity_mask"] = sparsity_mask worker_ctx["handle_collisions"] = handle_collisions worker_ctx["delta_collision_samples"] = delta_collision_samples @@ -306,7 +303,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) recording = worker_ctx["recording"] all_templates = worker_ctx["all_templates"] segment_slices = worker_ctx["segment_slices"] - unit_inds_to_channel_indices = worker_ctx["unit_inds_to_channel_indices"] + sparsity_mask = worker_ctx["sparsity_mask"] nbefore = worker_ctx["nbefore"] cut_out_before = worker_ctx["cut_out_before"] cut_out_after = worker_ctx["cut_out_after"] @@ -339,7 +336,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( - local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices + local_spikes, local_spikes_w_margin, delta_collision_samples, sparsity_mask ) else: collisions_local = {} @@ -354,7 +351,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] - sparse_indices = unit_inds_to_channel_indices[unit_index] + sparse_indices = sparsity_mask[unit_index] template = all_templates[unit_index][:, sparse_indices] template = template[nbefore - cut_out_before : nbefore + cut_out_after] sample_centered = sample_index - start_frame @@ -393,7 +390,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) right, nbefore, all_templates, - unit_inds_to_channel_indices, + sparsity_mask, cut_out_before, cut_out_after, ) @@ -410,14 +407,14 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) ### Collision handling ### -def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): +def _are_unit_indices_overlapping(sparsity_mask, i, j): """ Returns True if the unit indices i and j are overlapping, False otherwise Parameters ---------- - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices + sparsity_mask: boolean mask + The sparsity mask i: int The first unit index j: int @@ -428,13 +425,13 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): bool True if the unit indices i and j are overlapping, False otherwise """ - if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0: + if np.sum(np.logical_and(sparsity_mask[i], sparsity_mask[j])) > 0: return True else: return False -def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices): +def find_collisions(spikes, spikes_w_margin, delta_collision_samples, sparsity_mask): """ Finds the collisions between spikes. @@ -446,8 +443,8 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ An array of spikes within the added margin delta_collision_samples: int The maximum number of samples between two spikes to consider them as overlapping - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices + sparsity_mask: boolean mask + The sparsity mask Returns ------- @@ -480,7 +477,7 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ # find the overlapping spikes in space as well for possible_overlapping_spike_index in possible_overlapping_spike_indices: if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, + sparsity_mask, spike["unit_index"], spikes_w_margin[possible_overlapping_spike_index]["unit_index"], ): @@ -501,7 +498,7 @@ def fit_collision( right, nbefore, all_templates, - unit_inds_to_channel_indices, + sparsity_mask, cut_out_before, cut_out_after, ): @@ -528,8 +525,8 @@ def fit_collision( The number of samples before the spike to consider for the fit. all_templates: np.ndarray A numpy array of shape (n_units, n_samples, n_channels) containing the templates. - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices. + sparsity_mask: boolean mask + The sparsity mask cut_out_before: int The number of samples to cut out before the spike. cut_out_after: int @@ -547,14 +544,15 @@ def fit_collision( sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity - sparse_indices = np.array([], dtype="int") + sparse_indices = np.zeros(sparsity_mask.shape[1], dtype="int") for spike in collision: - sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] - sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + sparse_indices_i = sparsity_mask[spike["unit_index"]] + sparse_indices = np.logical_or(sparse_indices, sparse_indices_i) local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] + num_samples_local_waveform = local_waveform.shape[0] y = local_waveform.T.flatten() X = np.zeros((len(y), len(collision))) @@ -567,8 +565,10 @@ def fit_collision( # deal with borders if sample_centered - cut_out_before < 0: full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :] - elif sample_centered + cut_out_after > end_frame + right: - full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)] + elif sample_centered + cut_out_after > num_samples_local_waveform: + full_template[sample_centered - cut_out_before :] = template_cut[ + : -(cut_out_after + sample_centered - num_samples_local_waveform) + ] else: full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten()