Skip to content

Commit

Permalink
Use sparsity mask and handle right border correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 19, 2023
1 parent 77523e1 commit 12fd197
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
Original file line number Diff line number Diff line change
@@ -90,10 +90,7 @@ def _run(self, **job_kwargs):
if self._params["max_dense_channels"] is not None:
assert recording.get_num_channels() <= self._params["max_dense_channels"], ""
sparsity = ChannelSparsity.create_dense(we)
sparsity_inds = sparsity.unit_id_to_channel_indices

# easier to use in chunk function as spikes use unit_index instead o id
unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)}
sparsity_mask = sparsity.mask
all_templates = we.get_all_templates()

# precompute segment slice
@@ -113,7 +110,7 @@ def _run(self, **job_kwargs):
self.spikes,
all_templates,
segment_slices,
unit_inds_to_channel_indices,
sparsity_mask,
nbefore,
nafter,
cut_out_before,
@@ -262,7 +259,7 @@ def _init_worker_amplitude_scalings(
spikes,
all_templates,
segment_slices,
unit_inds_to_channel_indices,
sparsity_mask,
nbefore,
nafter,
cut_out_before,
@@ -282,7 +279,7 @@ def _init_worker_amplitude_scalings(
worker_ctx["cut_out_before"] = cut_out_before
worker_ctx["cut_out_after"] = cut_out_after
worker_ctx["return_scaled"] = return_scaled
worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices
worker_ctx["sparsity_mask"] = sparsity_mask
worker_ctx["handle_collisions"] = handle_collisions
worker_ctx["delta_collision_samples"] = delta_collision_samples

@@ -306,7 +303,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
recording = worker_ctx["recording"]
all_templates = worker_ctx["all_templates"]
segment_slices = worker_ctx["segment_slices"]
unit_inds_to_channel_indices = worker_ctx["unit_inds_to_channel_indices"]
sparsity_mask = worker_ctx["sparsity_mask"]
nbefore = worker_ctx["nbefore"]
cut_out_before = worker_ctx["cut_out_before"]
cut_out_after = worker_ctx["cut_out_after"]
@@ -339,7 +336,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right)
local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin]
collisions_local = find_collisions(
local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices
local_spikes, local_spikes_w_margin, delta_collision_samples, sparsity_mask
)
else:
collisions_local = {}
@@ -354,7 +351,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
continue
unit_index = spike["unit_index"]
sample_index = spike["sample_index"]
sparse_indices = unit_inds_to_channel_indices[unit_index]
sparse_indices = sparsity_mask[unit_index]
template = all_templates[unit_index][:, sparse_indices]
template = template[nbefore - cut_out_before : nbefore + cut_out_after]
sample_centered = sample_index - start_frame
@@ -393,7 +390,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
right,
nbefore,
all_templates,
unit_inds_to_channel_indices,
sparsity_mask,
cut_out_before,
cut_out_after,
)
@@ -410,14 +407,14 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)


### Collision handling ###
def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j):
def _are_unit_indices_overlapping(sparsity_mask, i, j):
"""
Returns True if the unit indices i and j are overlapping, False otherwise
Parameters
----------
unit_inds_to_channel_indices: dict
A dictionary mapping unit indices to channel indices
sparsity_mask: boolean mask
The sparsity mask
i: int
The first unit index
j: int
@@ -428,13 +425,13 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j):
bool
True if the unit indices i and j are overlapping, False otherwise
"""
if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0:
if np.sum(np.logical_and(sparsity_mask[i], sparsity_mask[j])) > 0:
return True
else:
return False


def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices):
def find_collisions(spikes, spikes_w_margin, delta_collision_samples, sparsity_mask):
"""
Finds the collisions between spikes.
@@ -446,8 +443,8 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_
An array of spikes within the added margin
delta_collision_samples: int
The maximum number of samples between two spikes to consider them as overlapping
unit_inds_to_channel_indices: dict
A dictionary mapping unit indices to channel indices
sparsity_mask: boolean mask
The sparsity mask
Returns
-------
@@ -480,7 +477,7 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_
# find the overlapping spikes in space as well
for possible_overlapping_spike_index in possible_overlapping_spike_indices:
if _are_unit_indices_overlapping(
unit_inds_to_channel_indices,
sparsity_mask,
spike["unit_index"],
spikes_w_margin[possible_overlapping_spike_index]["unit_index"],
):
@@ -501,7 +498,7 @@ def fit_collision(
right,
nbefore,
all_templates,
unit_inds_to_channel_indices,
sparsity_mask,
cut_out_before,
cut_out_after,
):
@@ -528,8 +525,8 @@ def fit_collision(
The number of samples before the spike to consider for the fit.
all_templates: np.ndarray
A numpy array of shape (n_units, n_samples, n_channels) containing the templates.
unit_inds_to_channel_indices: dict
A dictionary mapping unit indices to channel indices.
sparsity_mask: boolean mask
The sparsity mask
cut_out_before: int
The number of samples to cut out before the spike.
cut_out_after: int
@@ -547,14 +544,15 @@ def fit_collision(
sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left)

# construct sparsity as union between units' sparsity
sparse_indices = np.array([], dtype="int")
sparse_indices = np.zeros(sparsity_mask.shape[1], dtype="int")
for spike in collision:
sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]]
sparse_indices = np.union1d(sparse_indices, sparse_indices_i)
sparse_indices_i = sparsity_mask[spike["unit_index"]]
sparse_indices = np.logical_or(sparse_indices, sparse_indices_i)

local_waveform_start = max(0, sample_first_centered - cut_out_before)
local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after)
local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices]
num_samples_local_waveform = local_waveform.shape[0]

y = local_waveform.T.flatten()
X = np.zeros((len(y), len(collision)))
@@ -567,8 +565,10 @@ def fit_collision(
# deal with borders
if sample_centered - cut_out_before < 0:
full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :]
elif sample_centered + cut_out_after > end_frame + right:
full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)]
elif sample_centered + cut_out_after > num_samples_local_waveform:
full_template[sample_centered - cut_out_before :] = template_cut[
: -(cut_out_after + sample_centered - num_samples_local_waveform)
]
else:
full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut
X[:, i] = full_template.T.flatten()

0 comments on commit 12fd197

Please sign in to comment.